diff --git a/.clang-format b/.clang-format
index 3b267840004..c6488cb3585 100644
--- a/.clang-format
+++ b/.clang-format
@@ -94,7 +94,7 @@ PenaltyBreakString: 1000
 PenaltyBreakTemplateDeclaration: 10
 PenaltyExcessCharacter: 1000000
 PenaltyReturnTypeOnItsOwnLine: 200
-PointerAlignment: Left
+PointerAlignment: Right
 RawStringFormats:
   - Language: Cpp
     Delimiters:
diff --git a/mindspore/ccsrc/common/utils.cc b/mindspore/ccsrc/common/utils.cc
index 328a0591131..7109c121e52 100644
--- a/mindspore/ccsrc/common/utils.cc
+++ b/mindspore/ccsrc/common/utils.cc
@@ -23,7 +23,7 @@ namespace common {
 const int CACHED_STR_NUM = 1 << 8;
 const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
 std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
-const char* SafeCStr(const std::string&& str) {
+const char *SafeCStr(const std::string &&str) {
   static std::atomic<uint32_t> index{0};
   uint32_t cur_index = index++;
   cur_index = cur_index & CACHED_STR_MASK;
diff --git a/mindspore/ccsrc/common/utils.h b/mindspore/ccsrc/common/utils.h
index 7cee933ac8c..8f6e8f7c0c4 100644
--- a/mindspore/ccsrc/common/utils.h
+++ b/mindspore/ccsrc/common/utils.h
@@ -21,16 +21,16 @@
 #include

 #define DISABLE_COPY_AND_ASSIGN(ClassType) \
-  ClassType(const ClassType&) = delete;    \
-  ClassType& operator=(const ClassType&) = delete;
+  ClassType(const ClassType &) = delete;   \
+  ClassType &operator=(const ClassType &) = delete;

 namespace mindspore {
 namespace common {
-inline const char* SafeCStr(const std::string& str) { return str.c_str(); }
-const char* SafeCStr(const std::string&& str);
+inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
+const char *SafeCStr(const std::string &&str);

-static inline std::string GetEnv(const std::string& envvar) {
-  const char* value = ::getenv(envvar.c_str());
+static inline std::string GetEnv(const std::string &envvar) {
+  const char *value = ::getenv(envvar.c_str());
   if (value == nullptr) {
     return std::string();
diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/dataset/kernels/image/decode_op.h
index 50d2d3cb680..6e7180958a3 100644
--- a/mindspore/ccsrc/dataset/kernels/image/decode_op.h
+++ b/mindspore/ccsrc/dataset/kernels/image/decode_op.h
@@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {

   ~DecodeOp() = default;

-  Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override;
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

-  void Print(std::ostream& out) const override { out << "DecodeOp"; }
-  Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
-  Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
+  void Print(std::ostream &out) const override { out << "DecodeOp"; }
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

  private:
   bool is_rgb_format_ = true;
diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc
index e7a8cc34962..a28f2bb6fd4 100644
--- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc
+++ b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc
@@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int
   rnd_.seed(seed_);
 }

-Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input,
-                                         std::vector<std::shared_ptr<Tensor>>* output) {
+Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
+
std::vector> *output) { IO_CHECK_VECTOR(input, output); if (input.size() != NumInput()) return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); @@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector& inputs, - std::vector& outputs) { +Status DistortBoundingBoxCropOp::OutputShape(const std::vector &inputs, + std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); outputs.clear(); TensorShape out = TensorShape{-1, -1}; @@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector& inp if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } -Status DistortBoundingBoxCropOp::OutputType(const std::vector& inputs, std::vector& outputs) { +Status DistortBoundingBoxCropOp::OutputType(const std::vector &inputs, std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); outputs[0] = inputs[0]; return Status::OK(); diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h index 6d5dca99fb9..749c166d594 100644 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h @@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp { ~DistortBoundingBoxCropOp() override = default; - void Print(std::ostream& out) const override { + void Print(std::ostream &out) const override { out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; } - Status Compute(const std::vector>& input, - std::vector>* output) override; + Status Compute(const std::vector> &input, + std::vector> *output) override; uint32_t NumInput() override { return 5; } - Status OutputShape(const std::vector& inputs, std::vector& outputs) override; - Status OutputType(const std::vector& inputs, std::vector& outputs) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; private: int32_t max_attempts_; diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc index 3cf60656595..a3cf8cefb50 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc @@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ rnd_.seed(GetSeed()); } -Status RandomCropAndResizeOp::Compute(const std::shared_ptr& input, std::shared_ptr* output) { +Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); @@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr& input, std: (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); } -Status RandomCropAndResizeOp::OutputShape(const std::vector& inputs, std::vector& outputs) { +Status RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); outputs.clear(); TensorShape out = TensorShape{target_height_, 
target_width_}; @@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector& inputs if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } -Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { +Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { double scale, aspect; *crop_width = w_in; *crop_height = h_in; diff --git a/mindspore/ccsrc/debug/anf_ir_dump.h b/mindspore/ccsrc/debug/anf_ir_dump.h index 5c4bc5eacd6..a53888348d0 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.h +++ b/mindspore/ccsrc/debug/anf_ir_dump.h @@ -22,7 +22,7 @@ namespace mindspore { constexpr char PARALLEL_STRATEGY[] = "strategy"; -void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); +void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); } // namespace mindspore diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 8e626d6f9a7..6ebe3ad43f1 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; // get MindSpore Intermediate Representation Path std::string GetMsIrPath(void) { std::string path; - const char* path_ptr = getenv("MS_IR_PATH"); + const char *path_ptr = getenv("MS_IR_PATH"); if (path_ptr != nullptr) { path = path_ptr; char real_path[PATH_MAX] = {0}; @@ -62,13 +62,13 @@ std::string GetMsIrPath(void) { return path; } -std::string dump_obj(const py::object& obj, const std::string& path) { +std::string dump_obj(const py::object &obj, const std::string &path) { py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); return py::str(name); } -py::object load_obj(const std::string& path) { +py::object load_obj(const std::string &path) { py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); return obj; @@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) { // ============================================= MindSpore IR Exporter ============================================= -std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { +std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { abstract::ShapePtr shape = nd->Shape() == nullptr ? 
nullptr : dyn_cast(nd->Shape()); TypePtr type = dyn_cast(nd->Type()); std::ostringstream oss; @@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { return oss.str(); } -std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { +std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { std::string pkl_path = GetMsIrPath(); // if not specified env 'MS_IR_PATH', do not create any files if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { @@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca return file_prefix + file_name; } -int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { +int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { if (func_graph == nullptr || param == nullptr) { return -1; } @@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& // try to find index of parameter for SymbolicKeyInstance from all exported graphs // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different -int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { +int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) { if (param == nullptr) { return -1; } int ret = -1; - for (const auto& item : exported) { + for (const auto &item : exported) { auto pram_iter = item.second.find(param); if (pram_iter != item.second.end()) { return pram_iter->second; @@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { return ret; } -std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { +std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(node); return GetValueText(fg, node->value()); } -std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { +std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) { auto py_funcs = mt_func_graph->GetPyFunctions(); if (py_funcs.empty()) { return ""; @@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap oss << "{"; bool is_first = true; - for (const auto& py_func : py_funcs) { + for (const auto &py_func : py_funcs) { if (is_first) { is_first = false; } else { @@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap * ├── GradOperation * └── TupleAdd */ -std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { +std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) { if (meta_func_graph == nullptr) { return ""; } @@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_ return oss.str(); } -std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { +std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { std::ostringstream oss; if (prim == nullptr) { return oss.str(); @@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { if (prim->isa()) { auto do_signature = dyn_cast(prim); - auto& func = do_signature->function(); + auto &func = do_signature->function(); if (func->isa()) { auto sig_prim = dyn_cast(func); oss << sig_prim->GetAttrsText(); @@ -276,7 +276,7 @@ std::string 
AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { return oss.str(); } -std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { +std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { std::ostringstream oss; if (ns == nullptr) { return oss.str(); @@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { return oss.str(); } -std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, - const SymbolicKeyInstancePtr& sym_inst) { +std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, + const SymbolicKeyInstancePtr &sym_inst) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(sym_inst); AnfNodePtr sym_node = sym_inst->node(); @@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra return oss.str(); } -std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; // output ValueList, ValueTuple ValueSequeuePtr seq = dyn_cast(value); @@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V return oss.str(); } -std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; ValueDictionaryPtr dict = value->cast(); oss << "{"; bool first_flag = true; - for (const auto& elem : dict->value()) { + for (const auto &elem : dict->value()) { if (first_flag) { first_flag = false; } else { @@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value return oss.str(); } -std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { +std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) { std::ostringstream oss; if (check_integrity_) { @@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& return oss.str(); } -std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; bool is_null_ptr = (func_graph == nullptr || value == nullptr); if (is_null_ptr) { @@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu } // this function is used to output node in CNode's inputs -std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map) { +std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map) { std::ostringstream oss; if (func_graph == nullptr || node == nullptr) { return oss.str(); @@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An return oss.str(); } -void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector& parameters, - OrderedMap* param_map) { +void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector ¶meters, + OrderedMap *param_map) { bool first_flag = true; - for (const AnfNodePtr& param : parameters) { + for (const AnfNodePtr ¶m : parameters) { if (first_flag) { first_flag = false; ofs << " "; @@ -479,13 +479,13 @@ void 
AnfExporter::OutputParameters(std::ofstream& ofs, const std::vectorinputs(); + auto &inputs = node->inputs(); if (inputs.size() > 1) { ofs << " #("; for (size_t i = 1; i < inputs.size(); ++i) { @@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod ofs << " #scope: " << node->scope()->name(); } -void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector& nodes, - const FuncGraphPtr& func_graph) { +void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector &nodes, + const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } int idx = 1; std::map apply_map; - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector } auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); // non-return node if (node != func_graph->get_return()) { @@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector } } -void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { +void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun ofs << "}\n"; } -void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { +void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt ofs.close(); } -void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector& graphs) { +void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector &graphs) { if (graphs.empty()) { return; } @@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector param_index = 1; - for (const auto& tagged_graph : graphs) { + for (const auto &tagged_graph : graphs) { tagged_cnodes_ = tagged_graph.second; ExportOneFuncGraph(ofs, tagged_graph.first); tagged_cnodes_.clear(); @@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector } #ifdef ENABLE_DUMP_IR -void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { +void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap ChangeFileMode(filename, S_IRUSR); } -void ExportIR(const std::string& filename, const std::vector& graphs) { +void ExportIR(const std::string &filename, const std::vector &graphs) { AnfExporter exporter("", false); ChangeFileMode(filename, S_IRWXU); exporter.ExportFuncGraph(filename, graphs); @@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector& graph ChangeFileMode(filename, S_IRUSR); } #else -void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { +void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -693,7 +693,7 
@@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { << "please recompile source to enable it. See help of building script."; } -void ExportIR(const std::string& filename, const std::vector& graphs) { +void ExportIR(const std::string &filename, const std::vector &graphs) { static bool already_printed = false; if (already_printed) { return; @@ -732,7 +732,7 @@ enum Token : int { TOK_ERROR // file read error }; -std::map token_text = { +std::map token_text = { {TOK_INVALID, "invalid"}, // invalid token {TOK_LPARENTHESIS, "("}, // ( left parenthesis {TOK_RPARENTHESIS, ")"}, // ) right parenthesis @@ -761,14 +761,14 @@ std::map token_text = { class Lexer { public: // filename is checked in ImportIR; - explicit Lexer(const char* filename) : fin(filename) {} + explicit Lexer(const char *filename) : fin(filename) {} ~Lexer() { try { if (fin.is_open()) { fin.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file"; } catch (...) { std::string exName(abi::__cxa_current_exception_type()->name()); @@ -776,7 +776,7 @@ class Lexer { } } - bool IsSingleCharToken(char ch, Token* token_ptr) { + bool IsSingleCharToken(char ch, Token *token_ptr) { // clang-format off std::unordered_map char_to_token = { {'(', TOK_LPARENTHESIS}, @@ -806,7 +806,7 @@ class Lexer { Token GetNextToken() { #ifdef DEBUG Token token = GetNextTokenInner(); - const char* str = token_text[token]; + const char *str = token_text[token]; std::string text = (str == nullptr ? GetTokenText() : str); MS_LOG(DEBUG) << "------Parse token] " << text; return token; @@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE; class IrParser { public: - explicit IrParser(const char* filename) : lexer_(filename) {} + explicit IrParser(const char *filename) : lexer_(filename) {} ~IrParser() {} - py::object LoadObject(const std::string& file_name) const { + py::object LoadObject(const std::string &file_name) const { std::string pkl_path = GetMsIrPath(); py::object default_obj = load_obj(pkl_path + "/" + file_name); return default_obj; @@ -1087,7 +1087,7 @@ class IrParser { MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); } - Token ParseParent(FuncGraphPtr* const parent_ptr) { + Token ParseParent(FuncGraphPtr *const parent_ptr) { if (lexer_.GetNextToken() != TOK_IDENTIFIER) { return TOK_ERROR; } @@ -1168,7 +1168,7 @@ class IrParser { return func_graph; } - FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { + FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) { Token tok = lexer_.SkipWhiteToken(); while (tok == TOK_VARIABLE) { if (ParseStatement(func_graph) == nullptr) { @@ -1264,56 +1264,56 @@ class IrParser { return func_graph; } - void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { + void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const { if (ptr == nullptr) { return; } *ptr = dtype; } - void SetTupleType(TypePtr* ptr) { + void SetTupleType(TypePtr *ptr) { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { + void SetTupleType(TypePtr *ptr, const TypePtrList &elems) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elems); } - void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector&) { + void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector &) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem_type); } - void SetListType(TypePtr* ptr) { + void SetListType(TypePtr 
*ptr) { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetListType(TypePtr* ptr, const TypePtrList& elems) { + void SetListType(TypePtr *ptr, const TypePtrList &elems) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elems); } - void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { + void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem); } - void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { + void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const { if (ptr == nullptr) { return; } @@ -1321,45 +1321,45 @@ class IrParser { } // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} - void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { + void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} - void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} + void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {} + void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {} - void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { + void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { if (ptr == nullptr) { return; } // if one of elems is nullptr, just return - if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { + if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { return; } *ptr = std::make_shared(elems); } - void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector& shape) { + void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector &shape) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem_type, shape); } - void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { + void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { if (ptr == nullptr) { return; } - if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { + if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { return; } *ptr = std::make_shared(elems); } - void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { + void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) { if (ptr == nullptr) { return; } @@ -1367,7 +1367,7 @@ class IrParser { } template - Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { + Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) { if (tok != TOK_LBRACKET) { MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; return tok; @@ -1415,7 +1415,7 @@ class IrParser { } template - Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { + Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { if (tok != TOK_LPARENTHESIS) { if (ptr != nullptr) { SetBasicType(ptr, std::make_shared()); @@ -1454,7 +1454,7 @@ class IrParser { return lexer_.GetNextToken(); } - bool IsNumberType(const std::string& type, TypeId* 
typeid_ptr) { + bool IsNumberType(const std::string &type, TypeId *typeid_ptr) { // clang-format off static std::unordered_map basic_types = { {"Bool", kNumberTypeBool}, @@ -1486,7 +1486,7 @@ class IrParser { } template - void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { + void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) { TypePtr dtype = nullptr; std::unordered_map type_map = { @@ -1519,7 +1519,7 @@ class IrParser { } template - Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { + Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) { if (type == "NoneType") { SetBasicType(ptr, std::make_shared()); return lexer_.GetNextToken(); @@ -1541,7 +1541,7 @@ class IrParser { } template - Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { + Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { if (tok != TOK_IDENTIFIER) { return TOK_ERROR; } @@ -1588,11 +1588,11 @@ class IrParser { } } - Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { + Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) { return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); } - Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { + Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { Token tok = ParseAttribute(func_graph, prim); while (tok == TOK_COMMA) { tok = ParseAttribute(func_graph, prim); @@ -1603,7 +1603,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { + Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { Token tok = lexer_.GetNextToken(); if (tok != TOK_IDENTIFIER) { return TOK_ERROR; @@ -1670,7 +1670,7 @@ class IrParser { return tok == TOK_RPARENTHESIS ? 
func_graph : nullptr; } - FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector* const inputs_ptr) { + FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector *const inputs_ptr) { Token tok = ParseArgument(func_graph, inputs_ptr); while (tok == TOK_COMMA) { tok = ParseArgument(func_graph, inputs_ptr); @@ -1681,9 +1681,9 @@ class IrParser { return func_graph; } - AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { + AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) { while (func_graph != nullptr) { - for (auto& ptr : func_graph->parameters()) { + for (auto &ptr : func_graph->parameters()) { MS_EXCEPTION_IF_NULL(ptr); ParameterPtr param = ptr->cast(); MS_EXCEPTION_IF_NULL(param); @@ -1701,12 +1701,12 @@ class IrParser { return nullptr; } - bool Match(const std::string& str, const std::string& pattern) const { + bool Match(const std::string &str, const std::string &pattern) const { return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; } template - Token ParseScalar(ValuePtr* const val_ptr) { + Token ParseScalar(ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_NUMBER) { return TOK_ERROR; } @@ -1725,7 +1725,7 @@ class IrParser { } template - Token ParseScalar(ValuePtr* const val_ptr, Token tok) { + Token ParseScalar(ValuePtr *const val_ptr, Token tok) { if (tok != TOK_LPARENTHESIS) { *val_ptr = std::make_shared(); return tok; @@ -1735,7 +1735,7 @@ class IrParser { } template - Token ParseScalar(ValuePtr* const val_ptr, Token tok) { + Token ParseScalar(ValuePtr *const val_ptr, Token tok) { if (tok != TOK_LPARENTHESIS) { *val_ptr = std::make_shared(nbits); return tok; @@ -1745,7 +1745,7 @@ class IrParser { } template - T StringToScalar(const std::string& text) { + T StringToScalar(const std::string &text) { std::stringstream ss; T value; ss << text; @@ -1753,7 +1753,7 @@ class IrParser { return value; } - Token ParseTensor(ValuePtr* const val_ptr) { + Token ParseTensor(ValuePtr *const val_ptr) { // parse type TypeId type; if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { @@ -1803,7 +1803,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { + Token ParsePrimType(Token tok, PrimType *prim_type_ptr) { if (tok != TOK_LBRACE) { return tok; } @@ -1830,7 +1830,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { + Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { if (tok != TOK_LPARENTHESIS) { return TOK_ERROR; } @@ -1855,7 +1855,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { + Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { if (tok != TOK_LBRACE) { return tok; } @@ -1868,7 +1868,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseBoolValue(const std::string& key, bool* val_ptr) { + Token ParseBoolValue(const std::string &key, bool *val_ptr) { if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { return TOK_ERROR; } @@ -1892,7 +1892,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { + Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_LBRACE) { return TOK_ERROR; } @@ 
-1920,7 +1920,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { + Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) { if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { return TOK_ERROR; } @@ -1951,7 +1951,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { + Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_AT_FILE) { return TOK_ERROR; } @@ -1984,7 +1984,7 @@ class IrParser { return next; } - Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { + Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) { if (Match(id, "MultitypeFuncGraph::")) { std::string name = id.substr(strlen("MultitypeFuncGraph::")); auto mt_func_graph = std::make_shared(name); @@ -2024,8 +2024,8 @@ class IrParser { } } - Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, - AnfNodePtr* const node_ptr = nullptr) { + Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr, + AnfNodePtr *const node_ptr = nullptr) { if (id == "None") { *val_ptr = std::make_shared(); return lexer_.GetNextToken(); @@ -2075,9 +2075,9 @@ class IrParser { } } - Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, - const std::vector& elems, const std::vector& nodes, - ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { + Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid, + const std::vector &elems, const std::vector &nodes, + ValuePtr *const val_ptr, AnfNodePtr *node_ptr) { if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { if (node_is_valid && node_ptr != nullptr) { MS_EXCEPTION_IF_NULL(func_graph); @@ -2097,8 +2097,8 @@ class IrParser { } } - Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, - AnfNodePtr* node_ptr = nullptr) { + Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, + AnfNodePtr *node_ptr = nullptr) { Token left_tok = tok; std::vector elems; @@ -2138,7 +2138,7 @@ class IrParser { return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); } - Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { + Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) { // tuple or list if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); @@ -2152,7 +2152,7 @@ class IrParser { return TOK_ERROR; } - Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, + Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr, Token tok = TOK_INVALID) { if (tok == TOK_INVALID) { tok = lexer_.GetNextToken(); @@ -2193,7 +2193,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseArgument(const FuncGraphPtr& func_graph, std::vector* const inputs_ptr) { + Token ParseArgument(const FuncGraphPtr &func_graph, std::vector *const inputs_ptr) { Token tok = 
lexer_.GetNextToken(); if (tok == TOK_RPARENTHESIS) { return tok; @@ -2208,7 +2208,7 @@ class IrParser { return tok; } - const std::vector& GetFuncGraphs() const { return func_graphs_; } + const std::vector &GetFuncGraphs() const { return func_graphs_; } private: Lexer lexer_; @@ -2226,14 +2226,14 @@ class IrParser { std::map param_nodes_; // map parameter name to parameter }; -std::vector ImportIR(const std::string& filename) { +std::vector ImportIR(const std::string &filename) { IrParser parser(filename.c_str()); parser.ParseFile(); return parser.GetFuncGraphs(); } #ifdef ENABLE_DUMP_IR -void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { +void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { if (func_graph == nullptr) { MS_LOG(ERROR) << "Func graph is nullptr"; return; @@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { return; } char real_path[PATH_MAX] = {0}; - char* real_path_ret = nullptr; + char *real_path_ret = nullptr; #if defined(_WIN32) || defined(_WIN64) real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); #else @@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { ChangeFileMode(file_path, S_IRUSR); } #else -void DumpIRProto(const FuncGraphPtr&, const std::string&) { +void DumpIRProto(const FuncGraphPtr &, const std::string &) { static bool already_printed = false; if (already_printed) { return; diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 5342c1ab965..6c8601c4af4 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -39,7 +39,7 @@ namespace mindspore { struct ParamPtrEqual { - bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { + bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const { const ParameterPtr param1 = dyn_cast(t1); const ParameterPtr param2 = dyn_cast(t2); @@ -52,7 +52,7 @@ struct ParamPtrEqual { }; struct ParamPtrHasher { - std::size_t operator()(AnfNodePtr const& param) const { + std::size_t operator()(AnfNodePtr const ¶m) const { const ParameterPtr parameter = dyn_cast(param); if (parameter == nullptr) { return 0; @@ -64,39 +64,39 @@ struct ParamPtrHasher { class AnfExporter { public: - explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) + explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { func_graph_set.clear(); exported.clear(); } virtual ~AnfExporter() {} - void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); - void ExportFuncGraph(const std::string& filename, const std::vector& graphs); + void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); + void ExportFuncGraph(const std::string &filename, const std::vector &graphs); protected: - virtual std::string GetNodeType(const AnfNodePtr& nd); - int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); - int GetParamIndexFromExported(const AnfNodePtr& param); - std::string DumpObject(const py::object& obj, const std::string& category) const; - std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); - std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); - std::string 
GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); - std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetPrimitiveText(const PrimitivePtr& prim); - std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetNameSpaceText(const parse::NameSpacePtr& ns); - std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); - std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map); - void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); - void OutputParameters(std::ofstream& ofs, const std::vector& parameters, - OrderedMap* param_map); + virtual std::string GetNodeType(const AnfNodePtr &nd); + int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); + int GetParamIndexFromExported(const AnfNodePtr ¶m); + std::string DumpObject(const py::object &obj, const std::string &category) const; + std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); + std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); + std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); + std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetPrimitiveText(const PrimitivePtr &prim); + std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetNameSpaceText(const parse::NameSpacePtr &ns); + std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph); + std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map); + void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); + void OutputParameters(std::ofstream &ofs, const std::vector ¶meters, + OrderedMap *param_map); - void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); - void OutputCNodes(std::ofstream& ofs, const std::vector& nodes, const FuncGraphPtr& func_graph); + void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); + void OutputCNodes(std::ofstream &ofs, const std::vector &nodes, const FuncGraphPtr &func_graph); int param_index; OrderedSet func_graph_set{}; @@ -108,16 +108,16 @@ class AnfExporter { abstract::AnfNodeConfigPtr node_cfg_ = nullptr; }; -void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); -void ExportIR(const std::string& filename, const std::vector& graphs); +void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); +void ExportIR(const std::string &filename, const std::vector &graphs); -std::vector ImportIR(const std::string& filename); +std::vector ImportIR(const std::string &filename); -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); -void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); +void DumpIRProto(const FuncGraphPtr &func_graph, const 
std::string &suffix); -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index 3e8cbfba194..d3b92532fac 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -34,7 +34,7 @@ namespace draw { namespace { // Only for ValueNode -std::string ValueType(const ValueNodePtr& node) { +std::string ValueType(const ValueNodePtr &node) { if (node == nullptr) { return ""; } @@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { return v->type_name(); } -std::string ReplaceSpecialChar(const std::string& str) { +std::string ReplaceSpecialChar(const std::string &str) { std::ostringstream oss; for (size_t i = 0; i < str.size(); i++) { if (str[i] == '<') { @@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { } // namespace // API of debug utils -void DrawNodes(const std::vector& nodes, OrderedMap>* sub_graphs, +void DrawNodes(const std::vector &nodes, OrderedMap> *sub_graphs, bool is_user) { if (sub_graphs == nullptr) { return; } - for (auto& nd : nodes) { + for (auto &nd : nodes) { MS_EXCEPTION_IF_NULL(nd); auto sub_graph = nd->func_graph(); if (sub_graph != nullptr) { @@ -84,16 +84,16 @@ void DrawNodes(const std::vector& nodes, OrderedMap& nodes, - OrderedMap>* sub_graphs) { +void DrawValueNodes(const std::vector &nodes, + OrderedMap> *sub_graphs) { if (sub_graphs == nullptr) { return; } int dup_idx = 0; - for (auto& nd : nodes) { - for (auto& t : SuccIncoming(nd)) { + for (auto &nd : nodes) { + for (auto &t : SuccIncoming(nd)) { MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(nd); if (t->isa() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { @@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector& nodes, } } -void DrawEdges(const std::vector& nodes, const std::shared_ptr& digraph, bool is_user) { +void DrawEdges(const std::vector &nodes, const std::shared_ptr &digraph, bool is_user) { if (digraph == nullptr) { return; } @@ -120,11 +120,11 @@ void DrawEdges(const std::vector& nodes, const std::shared_ptrisa() || t->isa()) { if ((!is_user) || (i != 0)) { @@ -143,7 +143,7 @@ void DrawEdges(const std::vector& nodes, const std::shared_ptrSubGraph(gsub.first, gsub.second); } @@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use } #ifdef ENABLE_DUMP_IR -void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { +void Draw(const std::string &filename, const FuncGraphPtr &func_graph) { const std::string dot_suffix = ".dot"; std::string filename_with_suffix = (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; DrawByOpt(filename_with_suffix, func_graph, false); } -void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { +void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { DrawByOpt(filename, func_graph, true); } #else -void Draw(const std::string&, const FuncGraphPtr&) { +void Draw(const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { << "please recompile source to enable it. 
See help of building script."; } -void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { +void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { return "plaintext"; } -std::string Graphviz::Color(const AnfNodePtr& node) { +std::string Graphviz::Color(const AnfNodePtr &node) { if (node == nullptr) { return ""; } @@ -259,7 +259,7 @@ void BaseDigraph::Start() { buffer_ << "compound=true" << std::endl; } -void BaseDigraph::Head(const AnfNodePtr& node, int id) { +void BaseDigraph::Head(const AnfNodePtr &node, int id) { if (node == nullptr) { return; } @@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { } } -void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { +void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) { if (node == nullptr) { return; } @@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { buffer_ << ":" << idx; } -void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { +void BaseDigraph::Tail(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -304,12 +304,12 @@ void BaseDigraph::End() { } } -void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { +void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { buffer_ << "parameters_" << key << "[shape=plaintext "; buffer_ << "label=<"; buffer_ << ""; int count = 0; - for (auto& parameter : key->parameters()) { + for (auto ¶meter : key->parameters()) { buffer_ << "
parameters
"; buffer_ << parameter->ToString(); auto py_p = dyn_cast(parameter)->default_param(); @@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { buffer_ << "
>,];"; } -void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr& gsub) { +void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr &gsub) { if (key == nullptr || gsub == nullptr) { return; } @@ -361,12 +361,12 @@ Digraph::~Digraph() { if (fout_.is_open()) { fout_.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file " << filename_; } } -static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { +static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) { size_t start_pos = 0; while ((start_pos = str.find(from, start_pos)) != std::string::npos) { (void)str.replace(start_pos, from.length(), to); @@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st return str; } -static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { +static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(graph_obj); graph_obj->buffer() << "label=<"; @@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { graph_obj->buffer() << ""; graph_obj->buffer() << "
"; int i = 0; - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (i != 0) { graph_obj->buffer() << "
"; } @@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { graph_obj->buffer() << "
>,"; } -static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { +static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr) { return; } @@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { } } -static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { +static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr || node->size() == 0) { return; } @@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { } graph_obj->buffer() << ">"; int i = 0; - for (auto& attr : attrs) { + for (auto &attr : attrs) { if (i != 0) { graph_obj->buffer() << "
"; } @@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { if (fout_.is_open()) { fout_.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "exception when closing file " << filename_; } } diff --git a/mindspore/ccsrc/debug/draw.h b/mindspore/ccsrc/debug/draw.h index 4781a6c231f..7804c6e94a6 100644 --- a/mindspore/ccsrc/debug/draw.h +++ b/mindspore/ccsrc/debug/draw.h @@ -31,9 +31,9 @@ namespace parse = mindspore::parse; class Graphviz { public: - Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} + Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {} - explicit Graphviz(const std::string& name) : name_(name) {} + explicit Graphviz(const std::string &name) : name_(name) {} virtual ~Graphviz() {} @@ -41,8 +41,8 @@ class Graphviz { virtual void End() {} virtual std::string Shape(AnfNodePtr node); - std::string Color(const AnfNodePtr& node); - std::ostringstream& buffer() { return buffer_; } + std::string Color(const AnfNodePtr &node); + std::ostringstream &buffer() { return buffer_; } std::ostringstream buffer_; protected: @@ -53,8 +53,8 @@ class Graphviz { class BaseDigraph : public Graphviz { public: - BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} - explicit BaseDigraph(const std::string& name) : Graphviz(name) {} + BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {} + explicit BaseDigraph(const std::string &name) : Graphviz(name) {} ~BaseDigraph() override = default; virtual void Node(AnfNodePtr node, int id = 0) = 0; @@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { void Start() override; void End() override; virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); - void FuncGraphParameters(const FuncGraphPtr& key); - void SubGraph(const FuncGraphPtr& key, const std::shared_ptr& gsub); + void FuncGraphParameters(const FuncGraphPtr &key); + void SubGraph(const FuncGraphPtr &key, const std::shared_ptr &gsub); - const std::string& name() const { return name_; } + const std::string &name() const { return name_; } protected: - void Head(const AnfNodePtr& node, int id = 0); - void Tail(const AnfNodePtr& node, int idx, int id = 0); - void Tail(const FuncGraphPtr& func_graph); + void Head(const AnfNodePtr &node, int id = 0); + void Tail(const AnfNodePtr &node, int idx, int id = 0); + void Tail(const FuncGraphPtr &func_graph); }; class Digraph : public BaseDigraph { public: - Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} - explicit Digraph(const std::string& name) : BaseDigraph(name) {} + Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} + explicit Digraph(const std::string &name) : BaseDigraph(name) {} ~Digraph() override; void Node(AnfNodePtr node, int id = 0) override; @@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { class ModelDigraph : public BaseDigraph { public: - ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} - explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} + ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} + explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {} ~ModelDigraph() override; std::string Shape(AnfNodePtr node) override; @@ -96,8 +96,8 @@ class 
ModelDigraph : public BaseDigraph { }; // API to draw -void Draw(const std::string& filename, const FuncGraphPtr& func_graph); -void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); +void Draw(const std::string &filename, const FuncGraphPtr &func_graph); +void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); } // namespace draw } // namespace mindspore diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc index a7a1e208a4e..83ab1e45051 100644 --- a/mindspore/ccsrc/debug/dump_proto.cc +++ b/mindspore/ccsrc/debug/dump_proto.cc @@ -33,38 +33,38 @@ class ProtoExporter { ProtoExporter() {} ~ProtoExporter() {} - std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); + std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); private: void InitModelInfo(); - void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); - std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map, - std::map* const_map_ptr); - void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); - void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); - void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); - void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); - void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); - void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); + void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto); + std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr); + void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto); + void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto); + void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto); + void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto); + void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto); + void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto); - void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); - void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); - void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, - std::map* const_map_ptr); - void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* apply_map_ptr, - std::map* const_map_ptr, irpb::GraphProto* graph_proto); - void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, - const std::map& apply_map, std::map* const_map_ptr, - irpb::GraphProto* graph_proto); - void ExportValueNodes(const std::map& const_map, irpb::GraphProto* graph_proto); + void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); + void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, + std::map *const_map_ptr); + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *apply_map_ptr, + std::map *const_map_ptr, 
irpb::GraphProto *graph_proto); + void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, std::map *const_map_ptr, + irpb::GraphProto *graph_proto); + void ExportValueNodes(const std::map &const_map, irpb::GraphProto *graph_proto); static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } irpb::ModelProto model_; }; -static irpb::DataType GetNumberDataType(const TypePtr& type) { +static irpb::DataType GetNumberDataType(const TypePtr &type) { switch (type->type_id()) { case kNumberTypeBool: return irpb::DT_BOOL; @@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) { } } -void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { +void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { if (type_proto == nullptr) { return; } @@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s type_proto->set_data_type(irpb::DT_TENSOR); if (shape != nullptr && shape->isa()) { abstract::ShapePtr shape_info = dyn_cast(shape); - for (const auto& elem : shape_info->shape()) { + for (const auto &elem : shape_info->shape()) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); } } } else if (type->isa()) { TuplePtr tuple_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_TUPLE); - for (const auto& elem_type : tuple_type->elements()) { + for (const auto &elem_type : tuple_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { @@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s } else if (type->isa()) { ListPtr list_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_LIST); - for (const auto& elem_type : list_type->elements()) { + for (const auto &elem_type : list_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { @@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s } } -void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { +void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) { if (node == nullptr || type_proto == nullptr) { return; } SetNodeOutputType(node->Type(), node->Shape(), type_proto); } -void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const StringImmPtr& value = dyn_cast(val); + const StringImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_STRING); value_proto->set_str_val(value->value()); } else if (val->isa()) { @@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value } else if (val->isa()) { tensor::TensorPtr tensor_ptr = dyn_cast(val); value_proto->set_dtype(irpb::DT_TENSOR); - irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); + irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); - for (auto& elem : tensor_ptr->shape()) { + for (auto &elem : 
tensor_ptr->shape()) { tensor_proto->add_dims(elem); } } else if (val->isa()) { value_proto->set_dtype(irpb::DT_TYPE); - irpb::TypeProto* type_proto = value_proto->mutable_type_val(); + irpb::TypeProto *type_proto = value_proto->mutable_type_val(); type_proto->set_data_type(irpb::DT_TENSOR); TypePtr elem_type = dyn_cast(val)->element(); type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); @@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value } } -void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const BoolImmPtr& value = dyn_cast(val); + const BoolImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_BOOL); value_proto->set_bool_val(value->value()); } else if (val->isa()) { - const Int8ImmPtr& value = dyn_cast(val); + const Int8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT8); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int16ImmPtr& value = dyn_cast(val); + const Int16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT16); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int32ImmPtr& value = dyn_cast(val); + const Int32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT32); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int64ImmPtr& value = dyn_cast(val); + const Int64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT64); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const UInt8ImmPtr& value = dyn_cast(val); + const UInt8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT8); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt16ImmPtr& value = dyn_cast(val); + const UInt16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT16); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt32ImmPtr& value = dyn_cast(val); + const UInt32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT32); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt64ImmPtr& value = dyn_cast(val); + const UInt64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT64); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const FP32ImmPtr& value = dyn_cast(val); + const FP32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_FLOAT32); value_proto->set_float_val(value->value()); } else if (val->isa()) { - const FP64ImmPtr& value = dyn_cast(val); + const FP64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_FLOAT64); value_proto->set_double_val(value->value()); } else { @@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val } } -void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const ValueTuplePtr& value = dyn_cast(val); + const ValueTuplePtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_TUPLE); - for (const auto& item : value->value()) { + for (const auto &item : value->value()) { 
SetValueToProto(item, value_proto->add_values()); } } else if (val->isa()) { - const ValueListPtr& value = dyn_cast(val); + const ValueListPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_LIST); - for (const auto& item : value->value()) { + for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } } -void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } value_proto->set_dtype(irpb::DT_DICT); - for (const auto& item : val->value()) { - irpb::NamedValueProto* named_val = value_proto->add_dict_val(); + for (const auto &item : val->value()) { + irpb::NamedValueProto *named_val = value_proto->add_dict_val(); named_val->set_key(item.first); SetValueToProto(item.second, named_val->mutable_value()); } } -void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { +void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) { if (node == nullptr || node_proto == nullptr) { return; } @@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); } - const PrimitivePtr& prim = GetValueNode(node); + const PrimitivePtr &prim = GetValueNode(node); node_proto->set_op_type(prim->name()); - for (const auto& attr : prim->attrs()) { - irpb::AttributeProto* attr_proto = node_proto->add_attribute(); + for (const auto &attr : prim->attrs()) { + irpb::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name(attr.first); SetValueToProto(attr.second, attr_proto->mutable_value()); } node_proto->set_scope(node->scope()->name()); } -std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, - const std::map& apply_map, - std::map* const_map_ptr) { +std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr) { if (node == nullptr || const_map_ptr == nullptr) { return ""; } @@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt MS_LOG(EXCEPTION) << "Unknown node type. 
node is '" << node->ToString() << "'"; } -std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { +std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } InitModelInfo(); - irpb::GraphProto* graph_proto = model_.mutable_graph(); + irpb::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } -void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } @@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP ExportValueNodes(const_map, graph_proto); } -void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } std::vector parameters = func_graph->parameters(); - for (auto& param : parameters) { - irpb::ParameterProto* param_proto = graph_proto->add_parameters(); + for (auto ¶m : parameters) { + irpb::ParameterProto *param_proto = graph_proto->add_parameters(); param_proto->set_name(param->ToString()); SetNodeOutputType(param, param_proto->mutable_type()); @@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph } } -void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, - std::map* const_map_ptr) { +void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, + std::map *const_map_ptr) { if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { return; } // topo sort nodes std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::map apply_map; - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt } } -void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* apply_map_ptr, - std::map* const_map_ptr, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *apply_map_ptr, + std::map *const_map_ptr, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || graph_proto == nullptr) { return; @@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& auto apply_idx = apply_map_ptr->size() + 1; (*apply_map_ptr)[node] = apply_idx; - auto& inputs = node->inputs(); + auto &inputs = node->inputs(); if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } AnfNodePtr op = inputs[0]; - irpb::NodeProto* node_proto = graph_proto->add_node(); + irpb::NodeProto *node_proto = graph_proto->add_node(); // CNode/ConstGraph/Const/Parameter if (op->isa() || IsValueNode(op) || op->isa()) { @@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& // process OP inputs for (size_t i = 1; i < inputs.size(); ++i) { - irpb::InputProto* 
input_proto = node_proto->add_input(); + irpb::InputProto *input_proto = node_proto->add_input(); input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); input_proto->set_name(id); @@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& } } -void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, - const std::map& apply_map, - std::map* const_map_ptr, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, + std::map *const_map_ptr, irpb::GraphProto *graph_proto) { if (ret_node == nullptr || !ret_node->isa()) { MS_LOG(EXCEPTION) << "Graph return node is illegal"; } @@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } - irpb::OutputProto* output_proto = graph_proto->add_outputs(); + irpb::OutputProto *output_proto = graph_proto->add_outputs(); if (output_proto == nullptr) { MS_LOG(EXCEPTION) << "output_proto is nullptr"; } @@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const SetNodeOutputType(arg, output_proto->mutable_type()); } -static bool CompareValue(const std::pair& x, const std::pair& y) { +static bool CompareValue(const std::pair &x, const std::pair &y) { return x.second < y.second; } -void ProtoExporter::ExportValueNodes(const std::map& const_map, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportValueNodes(const std::map &const_map, irpb::GraphProto *graph_proto) { std::vector> nodes; (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), - [](const std::pair& item) { return item; }); + [](const std::pair &item) { return item; }); sort(nodes.begin(), nodes.end(), CompareValue); - for (auto& item : nodes) { + for (auto &item : nodes) { if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } - irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); + irpb::NamedValueProto *named_value = graph_proto->add_const_vals(); MS_EXCEPTION_IF_NULL(named_value); named_value->set_key(GetConstNodeId(item.second)); SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); @@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map& const_m void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { ProtoExporter exporter; return exporter.GetFuncGraphProtoString(func_graph); } diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc index fbe76cdc474..34d401191a2 100644 --- a/mindspore/ccsrc/debug/e2e_dump.cc +++ b/mindspore/ccsrc/debug/e2e_dump.cc @@ -36,7 +36,7 @@ Dump::Dump() dump_iter_(0), cur_iter_(0) {} -bool Dump::IsKernelNeedDump(const std::string& kernel_name) { +bool Dump::IsKernelNeedDump(const std::string &kernel_name) { if (dump_mode_ == 0) { // Dump All Kernels mode return true; @@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { return false; } -bool Dump::ParseDumpConfig(const std::string& dump_config_file) { +bool Dump::ParseDumpConfig(const std::string &dump_config_file) { std::ifstream 
jsonFile(dump_config_file); if (!jsonFile.is_open()) { MS_LOG(ERROR) << dump_config_file << " open failed."; @@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { return true; } -bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { +bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) { if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || @@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { return true; } -bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { +bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { auto trans_flag = dumpSettings.at("trans_flag"); auto enable = dumpSettings.at("enable"); auto mode = dumpSettings.at("mode"); @@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { dump_path_ = path; dump_net_name_ = net_name; dump_iter_ = iteration; - for (const auto& kernel : kernels) { + for (const auto &kernel : kernels) { dump_kernels_.push_back(kernel); } return true; } bool Dump::SetDumpConfFromJsonFile() { - const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); + const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); if (config_path_str != nullptr) { MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; } else { @@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { return ParseDumpConfig(dump_config_file); } -bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { +bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) { if (filename.empty() || data == nullptr || len == 0) { MS_LOG(ERROR) << "Incorrect parameter."; return false; @@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) MS_LOG(ERROR) << "Open file " << realpath << " fail."; return false; } - (void)fd.write(reinterpret_cast(data), SizeToLong(len)); + (void)fd.write(reinterpret_cast(data), SizeToLong(len)); fd.close(); return true; } -bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { +bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { MS_EXCEPTION_IF_NULL(outpath); auto path_split_pos = inpath.find_last_of('/'); if (path_split_pos == std::string::npos) { @@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { return true; } -bool Dump::CreateNotExistDirs(const std::string& path) { +bool Dump::CreateNotExistDirs(const std::string &path) { std::shared_ptr fs = system::Env::GetFileSystem(); MS_EXCEPTION_IF_NULL(fs); char temp_path[PATH_MAX] = {0}; diff --git a/mindspore/ccsrc/debug/e2e_dump.h b/mindspore/ccsrc/debug/e2e_dump.h index 2410dfb09af..4c3e8308da7 100644 --- a/mindspore/ccsrc/debug/e2e_dump.h +++ b/mindspore/ccsrc/debug/e2e_dump.h @@ -43,11 +43,11 @@ class Dump { uint32_t cur_iter() const { return cur_iter_; } - bool IsKernelNeedDump(const std::string& kernel_name); + bool IsKernelNeedDump(const std::string &kernel_name); bool SetDumpConfFromJsonFile(); - static bool DumpToFile(const std::string& filename, const void* data, size_t len); + static bool DumpToFile(const std::string &filename, const void *data, size_t len); protected: bool dump_enable_; @@ -59,14 +59,14 @@ class Dump { 
uint32_t cur_iter_; std::vector dump_kernels_; - static bool GetRealPath(const std::string& inpath, std::string* outpath); + static bool GetRealPath(const std::string &inpath, std::string *outpath); - static bool CreateNotExistDirs(const std::string& path); + static bool CreateNotExistDirs(const std::string &path); private: - bool ParseDumpConfig(const std::string& dump_config_file); - bool IsConfigExist(const nlohmann::json& dumpSettings); - bool IsConfigValid(const nlohmann::json& dumpSettings); + bool ParseDumpConfig(const std::string &dump_config_file); + bool IsConfigExist(const nlohmann::json &dumpSettings); + bool IsConfigValid(const nlohmann::json &dumpSettings); }; using DumpConfPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/debug/info.cc b/mindspore/ccsrc/debug/info.cc index 3c43bfa9b13..7903e554d90 100644 --- a/mindspore/ccsrc/debug/info.cc +++ b/mindspore/ccsrc/debug/info.cc @@ -23,7 +23,7 @@ #include "pipeline/parse/python_adapter.h" namespace mindspore { -std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { +std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { std::string temp_line = line; if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && tip != kSourceLineTipDiscard) { @@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { name_ = ""; } -DebugInfo::DebugInfo(const std::string& name) { +DebugInfo::DebugInfo(const std::string &name) { InitValueFromContext(); unique_id_ = gen_unique_id(); debug_id_ = -1; name_ = name; } -DebugInfo::DebugInfo(const LocationPtr& loc) { +DebugInfo::DebugInfo(const LocationPtr &loc) { InitValueFromContext(); unique_id_ = gen_unique_id(); debug_id_ = -1; @@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { } int64_t DebugInfo::unique_id_through_copy() const { - TraceInfoPtr trace_info = const_cast(this)->trace_info(); + TraceInfoPtr trace_info = const_cast(this)->trace_info(); if (trace_info != nullptr) { if (trace_info->isa() && trace_info->debug_info() != nullptr) { return trace_info->debug_info()->unique_id_through_copy(); @@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { } return DebugInfo::location(); } -void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } +void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } TraceContextPtr TraceManager::CurrentContextInfo() { if (!TraceManager::trace_context_stack_.empty()) { @@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { return nullptr; } -void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { +void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { TraceContextPtr context = std::make_shared(location); context->set_func_name(func_name); TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const LocationPtr& location) { +void TraceManager::DebugTrace(const LocationPtr &location) { TraceContextPtr context = std::make_shared(location); TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { +void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { if (trace_info == nullptr) { MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; } @@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { TraceManager::trace_context_stack_.push(context); } -void 
TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { +void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { if (trace_info == nullptr) { MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; } diff --git a/mindspore/ccsrc/debug/info.h b/mindspore/ccsrc/debug/info.h index da641ab74bb..a34d6e3df51 100644 --- a/mindspore/ccsrc/debug/info.h +++ b/mindspore/ccsrc/debug/info.h @@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou // Location class record the location in source code. class Location { public: - Location(const std::string& file_name, int line, int column, int line_end, int column_end) + Location(const std::string &file_name, int line, int column, int line_end, int column_end) : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} - Location(const Location& loc) + Location(const Location &loc) : file_name_(loc.file_name_), line_(loc.line_), column_(loc.column_), @@ -77,21 +77,21 @@ class TraceManager { TraceManager() = default; ~TraceManager() = default; static TraceContextPtr CurrentContextInfo(); - static void DebugTrace(const std::string& func_name, const LocationPtr& location); - static void DebugTrace(const LocationPtr& location); - static void DebugTrace(const TraceInfoPtr& trace_info); + static void DebugTrace(const std::string &func_name, const LocationPtr &location); + static void DebugTrace(const LocationPtr &location); + static void DebugTrace(const TraceInfoPtr &trace_info); // debug trace with a cloned trace info with debug_info - static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); + static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); static void EndTrace(); static std::stack trace_context_stack_; }; class TraceGuard { public: - explicit TraceGuard(const std::string func_name, const LocationPtr& location) { + explicit TraceGuard(const std::string func_name, const LocationPtr &location) { TraceManager::DebugTrace(func_name, location); } - explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } + explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } ~TraceGuard() { TraceManager::EndTrace(); } }; @@ -106,23 +106,23 @@ class TraceContext { public: ~TraceContext() = default; - explicit TraceContext(const LocationPtr& loc) { + explicit TraceContext(const LocationPtr &loc) { ProcessAttributeFromContext(); location_ = loc; } - explicit TraceContext(const std::string& func_name) { + explicit TraceContext(const std::string &func_name) { ProcessAttributeFromContext(); func_name_ = func_name; } - explicit TraceContext(const TraceInfoPtr& trace_info) { + explicit TraceContext(const TraceInfoPtr &trace_info) { ProcessAttributeFromContext(); trace_info_ = trace_info; } - void set_location(const LocationPtr& loc) { location_ = loc; } + void set_location(const LocationPtr &loc) { location_ = loc; } LocationPtr location() { return location_; } - void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } TraceInfoPtr trace_info() { return trace_info_; } - void set_func_name(const std::string& func_name) { func_name_ = func_name; } + void set_func_name(const std::string &func_name) { func_name_ = func_name; } std::string func_name() { return func_name_; } }; @@ -130,9 
+130,9 @@ class DebugInfo : public Base { public: DebugInfo(); - explicit DebugInfo(const std::string& name); + explicit DebugInfo(const std::string &name); - explicit DebugInfo(const LocationPtr& loc); + explicit DebugInfo(const LocationPtr &loc); virtual ~DebugInfo() = default; MS_DECLARE_PARENT(DebugInfo, Base); @@ -141,12 +141,12 @@ class DebugInfo : public Base { int64_t unique_id_through_copy() const; std::string get_id() { return std::to_string(debug_id()); } - void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } TraceInfoPtr trace_info() { return trace_info_; } - void set_location(const LocationPtr& loc) { location_ = loc; } + void set_location(const LocationPtr &loc) { location_ = loc; } virtual LocationPtr location() { return location_; } std::string name() { return name_; } - void set_name(const std::string& name) { name_ = name; } + void set_name(const std::string &name) { name_ = name; } virtual std::string debug_name(); virtual std::string get_python_func_belonged() { return ""; } @@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { py_func_belonged_ = context_info->func_name(); } } - explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { + explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { if (TraceManager::CurrentContextInfo() != nullptr) { auto context_info = TraceManager::CurrentContextInfo(); py_func_belonged_ = context_info->func_name(); @@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { ~NodeDebugInfo() override = default; std::string debug_name() override; - void set_node(const std::shared_ptr& node) { node_ = AnfNodeWeakPtr(node); } + void set_node(const std::shared_ptr &node) { node_ = AnfNodeWeakPtr(node); } std::shared_ptr get_node() const { return node_.lock(); } - void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } + void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } std::string get_python_func_belonged() override { return py_func_belonged_; } AnfNodeWeakPtr node_; std::string py_func_belonged_; @@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { } } - explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { + explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { if (TraceManager::CurrentContextInfo() != nullptr) { auto context_info = TraceManager::CurrentContextInfo(); py_func_name_ = context_info->func_name(); @@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { std::string debug_name() override; LocationPtr location() override; LocationPtr deco_location() { return deco_loc_; } - void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } + void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } FuncGraphPtr get_graph() const { return func_graph_.lock(); } - void set_full_name(const std::string& name) { full_name_ = name; } + void set_full_name(const std::string &name) { full_name_ = name; } std::string get_full_name() { return full_name_; } - void set_deco_location(const LocationPtr& deco_list_loc); + void set_deco_location(const LocationPtr &deco_list_loc); std::string get_python_func_belonged() override { return py_func_name_; } FuncGraphWeakPtr func_graph_; LocationPtr deco_loc_; diff --git a/mindspore/ccsrc/debug/label.cc b/mindspore/ccsrc/debug/label.cc index f0e16e831e8..d8c4986482b 100644 --- 
a/mindspore/ccsrc/debug/label.cc +++ b/mindspore/ccsrc/debug/label.cc @@ -31,7 +31,7 @@ struct NameWithTrace { std::string name; std::vector trace_labels; }; -static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { +static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { switch (trace_label) { case TraceLabelType::kShortSymbol: return trace_info->symbol(); @@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t } } -NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { +NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { NameWithTrace trace_name; // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node auto temp_info = debug_info; @@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe return trace_name; } -std::string CombineTraceTypes(const std::string& root_name, const std::vector& trace_labels) { +std::string CombineTraceTypes(const std::string &root_name, const std::vector &trace_labels) { std::string tags = ""; - for (auto& itr : trace_labels) { + for (auto &itr : trace_labels) { std::string symbol = itr; tags = tags + symbol; } @@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { return debug_with_loc_vec; } -DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); if (debug_with_loc_vec.size() > 0) { return debug_with_loc_vec[0]; @@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { } } -std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { if (info == nullptr) { return ""; } @@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { // a trace info identifies a node transform, so we can trace the node transform through // a link of trace info and debug info -std::string GetInfoWithAction(const std::vector& info_vec, SourceLineTip tip) { +std::string GetInfoWithAction(const std::vector &info_vec, SourceLineTip tip) { if (info_vec.size() < 1) { return ""; } @@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector& info_vec, SourceL return traced_info; } -std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { +std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { if (info == nullptr) { return ""; } @@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { return ""; } -std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { std::ostringstream oss; if (info == nullptr) { return ""; @@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So return oss.str(); } -std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { +std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) { std::ostringstream oss; oss << "graph:" << graph->ToString() << " with args["; auto params = graph->parameters(); @@ -151,8 
+151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas return oss.str(); } -void DumpInferStack(std::ostringstream& oss) { - auto& infer_stack = GetCurrenGraphInferStack(); +void DumpInferStack(std::ostringstream &oss) { + auto &infer_stack = GetCurrenGraphInferStack(); if (infer_stack.empty()) { return; } @@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { } std::reverse(infer_vec.begin(), infer_vec.end()); int index = 0; - for (auto& item : infer_vec) { + for (auto &item : infer_vec) { auto graph_infer = std::dynamic_pointer_cast(item.first); if (graph_infer == nullptr) { MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; @@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { } void TraceGraphInfer() { - auto& infer_stack = GetCurrenGraphInferStack(); + auto &infer_stack = GetCurrenGraphInferStack(); std::ostringstream oss; if (infer_stack.empty()) { return; @@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} ~AnalyzedFuncGraphExporter() override = default; - void ExportFuncGraph(const std::string& filename, const std::vector& node_cfgs); + void ExportFuncGraph(const std::string &filename, const std::vector &node_cfgs); private: - std::string GetNodeType(const AnfNodePtr& nd) override; + std::string GetNodeType(const AnfNodePtr &nd) override; }; std::unordered_map CalcTaggedFuncGraphs() { std::unordered_map tagged_func_graphs; - auto& list = GetCNodeDebugStack(); + auto &list = GetCNodeDebugStack(); for (size_t i = 0; i < list.size(); ++i) { auto node_cfg = list[i]; auto fg = node_cfg->context()->func_graph(); @@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); } -std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { +std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { if (node_cfg_ == nullptr) { return AnfExporter::GetNodeType(node); } @@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { return oss.str(); } -void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, - const std::vector& node_cfgs) { +void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, + const std::vector &node_cfgs) { if (node_cfgs.empty()) { MS_LOG(DEBUG) << "Node configs is empty"; return; @@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, auto tagged_func_graphs = CalcTaggedFuncGraphs(); // first output graph on the analysis stack - for (const auto& node_cfg : node_cfgs) { + for (const auto &node_cfg : node_cfgs) { auto fg = node_cfg->context()->func_graph(); // the graph is already output, skip it if (exported.find(fg) != exported.end()) { @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, ofs.close(); } -void GetInferStackInfo(std::ostringstream& oss) { +void GetInferStackInfo(std::ostringstream &oss) { MS_LOG(INFO) << "Get graph analysis information begin"; auto stack = GetCNodeDebugStack(); if (stack.empty()) { @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { static std::stack> graph_infer_stack; // trace the cnode infer debug info static std::vector cnode_debug_stack{}; -void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { +void TraceGraphInferEnter(const abstract::EvaluatorPtr 
&eval, const abstract::AnfNodeConfigPtr &node) { if (eval == nullptr) { MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; } @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An } } -void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { +void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { if (eval == nullptr) { MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; } @@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { } } -void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } +void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } -std::vector& GetCNodeDebugStack() { return cnode_debug_stack; } +std::vector &GetCNodeDebugStack() { return cnode_debug_stack; } -std::stack>& GetCurrenGraphInferStack() { +std::stack> &GetCurrenGraphInferStack() { return graph_infer_stack; } void ClearTraceStack() { diff --git a/mindspore/ccsrc/debug/trace.h b/mindspore/ccsrc/debug/trace.h index 5fba86fddd1..2704a80a354 100644 --- a/mindspore/ccsrc/debug/trace.h +++ b/mindspore/ccsrc/debug/trace.h @@ -31,19 +31,19 @@ namespace mindspore { namespace trace { -std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); -std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip = kSourceLineTipNextLine); -DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); void TraceGraphInfer(); -void GetInferStackInfo(std::ostringstream& oss); -void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); -void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); -void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); +void GetInferStackInfo(std::ostringstream &oss); +void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); +void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); +void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); void TraceInferCNodeLeave(); -std::vector& GetCNodeDebugStack(); -std::stack>& GetCurrenGraphInferStack(); -std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); +std::vector &GetCNodeDebugStack(); +std::stack> &GetCurrenGraphInferStack(); +std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); void ClearTraceStack(); } // namespace trace } // namespace mindspore diff --git a/mindspore/ccsrc/debug/trace_info.cc b/mindspore/ccsrc/debug/trace_info.cc index b01cd150101..19358e197a1 100644 --- a/mindspore/ccsrc/debug/trace_info.cc +++ b/mindspore/ccsrc/debug/trace_info.cc @@ -23,7 +23,7 @@ #include "pipeline/parse/python_adapter.h" namespace mindspore { -std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { +std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { if (info == nullptr) { return ""; } diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h index 16be9031e2c..e7a8c83dadd 100644 --- a/mindspore/ccsrc/debug/trace_info.h +++ b/mindspore/ccsrc/debug/trace_info.h @@ -40,13 +40,13 @@ 
using DebugInfoPtr = std::shared_ptr; // namespace to support intermediate representation definition class TraceInfo : public Base { public: - TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { + TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { symbol_ = symbol; full_name_ = full_name; name_ = full_name_; debug_info_ = info; } - TraceInfo(const TraceInfo& info) + TraceInfo(const TraceInfo &info) : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} virtual ~TraceInfo() = default; MS_DECLARE_PARENT(TraceInfo, Base); @@ -55,8 +55,8 @@ class TraceInfo : public Base { virtual std::string full_name() { return full_name_; } virtual TraceInfoPtr clone() { return shared_from_base(); } virtual std::string action_name() { return ""; } - virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); - void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } + virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); + void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } DebugInfoPtr debug_info() { return debug_info_; } DebugInfoPtr DebugInfoHasLoc(); std::vector> GetSourceCodeDebugInfo(); @@ -70,7 +70,7 @@ class TraceInfo : public Base { class TracePhi : public TraceInfo { public: - explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} + explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} MS_DECLARE_PARENT(TracePhi, TraceInfo); ~TracePhi() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -78,8 +78,8 @@ class TracePhi : public TraceInfo { class TraceIfStmtTrueBranch : public TraceInfo { public: - TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; - explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {} + TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; + explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); ~TraceIfStmtTrueBranch() override = default; TraceInfoPtr clone() override { @@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo { class TraceIfStmtFalseBranch : public TraceInfo { public: - TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; - explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {} + TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; + explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); ~TraceIfStmtFalseBranch() override = default; TraceInfoPtr clone() override { @@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo { class TraceIfStmtAfterBranch : public TraceInfo { public: - explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {} + explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); ~TraceIfStmtAfterBranch() override = default; TraceInfoPtr clone() override { @@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo { class TraceIfExpTrueBranch : public TraceInfo { public: - explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {} + explicit 
TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); ~TraceIfExpTrueBranch() override = default; TraceInfoPtr clone() override { @@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo { class TraceIfExpFalseBranch : public TraceInfo { public: - explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {} + explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); ~TraceIfExpFalseBranch() override = default; TraceInfoPtr clone() override { @@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo { class TraceCopy : public TraceInfo { public: TraceCopy() : TraceInfo(nullptr, "copy", "") {} - explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} + explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} MS_DECLARE_PARENT(TraceCopy, TraceInfo); ~TraceCopy() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo { class TraceIterator : public TraceInfo { public: - explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} + explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} MS_DECLARE_PARENT(TraceIterator, TraceInfo); ~TraceIterator() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo { class TraceWhileHeader : public TraceInfo { public: - explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {} + explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); ~TraceWhileHeader() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo { class TraceWhileBody : public TraceInfo { public: - explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {} + explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); ~TraceWhileBody() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo { class TraceWhileAfter : public TraceInfo { public: - explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {} + explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); ~TraceWhileAfter() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo { class TraceForHeader : public TraceInfo { public: - explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {} + explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {} MS_DECLARE_PARENT(TraceForHeader, TraceInfo); ~TraceForHeader() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo { class TraceForBody : public TraceInfo 
{ public: - explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {} + explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} MS_DECLARE_PARENT(TraceForBody, TraceInfo); ~TraceForBody() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo { class TraceForAfter : public TraceInfo { public: - explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {} + explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} MS_DECLARE_PARENT(TraceForAfter, TraceInfo); ~TraceForAfter() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo { class TraceEquiv : public TraceInfo { public: - explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} + explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} MS_DECLARE_PARENT(TraceEquiv, TraceInfo); ~TraceEquiv() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo { class TraceGradFpropApp : public TraceInfo { public: TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} - explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {} + explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); ~TraceGradFpropApp() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo { class TraceGradBpropApp : public TraceInfo { public: TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} - explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {} + explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); ~TraceGradBpropApp() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo { class TraceGradFprop : public TraceInfo { public: TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} - explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {} + explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); ~TraceGradFprop() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo { class TraceGradBprop : public TraceInfo { public: TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} - explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {} + explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); ~TraceGradBprop() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo { class TraceGradSens : public TraceInfo { public: TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} - 
explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {} + explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {} MS_DECLARE_PARENT(TraceGradSens, TraceInfo); ~TraceGradSens() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo { class TraceSpecialize : public TraceInfo { public: - explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } + explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); std::string name() override { return full_name_ + counter_; } std::string symbol() override { return counter_ + "_"; } @@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo { class TraceGradOperation : public TraceInfo { public: - explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} + explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); ~TraceGradOperation() override = default; TraceInfoPtr clone() override { @@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo { class TraceForceBool : public TraceInfo { public: - explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} + explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} MS_DECLARE_PARENT(TraceForceBool, TraceInfo); ~TraceForceBool() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo { class TraceExpandJ : public TraceInfo { public: - explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} + explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); ~TraceExpandJ() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo { class TraceGenMetaFuncGraph : public TraceInfo { public: - explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} + explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); ~TraceGenMetaFuncGraph() override = default; TraceInfoPtr clone() override { @@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo { class TraceEvaluatorGenGraph : public TraceInfo { public: - explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} + explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); ~TraceEvaluatorGenGraph() override = default; TraceInfoPtr clone() override { @@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo { class TraceResolve : public TraceInfo { public: - explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} + explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} MS_DECLARE_PARENT(TraceResolve, TraceInfo); ~TraceResolve() override = default; TraceInfoPtr clone() override { return 
std::make_shared(*shared_from_base()); } @@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo { class TraceTransform : public TraceInfo { public: TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } - explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { + explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { transform_name_ = transform_name; } @@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo { class TraceGenerateVarArg : public TraceInfo { public: - explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} + explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); ~TraceGenerateVarArg() override = default; TraceInfoPtr clone() override { @@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo { class TraceGenerateKwArg : public TraceInfo { public: - explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} + explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); ~TraceGenerateKwArg() override = default; TraceInfoPtr clone() override { @@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo { class TraceTrasformK : public TraceInfo { public: - explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} + explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); ~TraceTrasformK() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo { class TracePartialTransform : public TraceInfo { public: - explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} + explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); ~TracePartialTransform() override = default; TraceInfoPtr clone() override { @@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo { class TraceGetEnv : public TraceInfo { public: - explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} + explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); ~TraceGetEnv() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo { class TraceDoSignature : public TraceInfo { public: - explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} + explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); ~TraceDoSignature() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo { class TraceCombileLikeGraphs : public TraceInfo { public: TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} - explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} + explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : 
TraceInfo(info, "CombileLike", "L-") {} MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); ~TraceCombileLikeGraphs() override = default; TraceInfoPtr clone() override { diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc index 2c38e4290d0..69c6dca5760 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace device { namespace ascend { -size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { +size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (has_malloc_) { MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; } @@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { return size; } -bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { +bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { MS_EXCEPTION_IF_NULL(addr); has_malloc_ = false; free_mem_size_ = total_mem_size_; @@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } -void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { +void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { MS_EXCEPTION_IF_NULL(device_mem_pool_base); device_mem_pool_base_ = device_mem_pool_base; } diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h index a02bd453b2c..7fa3ebc23e8 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h @@ -26,12 +26,12 @@ namespace ascend { class AscendMemoryPool : public DynamicMemPoolBestFit { public: ~AscendMemoryPool() override = default; - AscendMemoryPool(const AscendMemoryPool&) = delete; - AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; + AscendMemoryPool(const AscendMemoryPool &) = delete; + AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; - size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; - bool FreeDeviceMem(const DeviceMemPtr& addr) override; - void set_device_mem_pool_base(uint8_t* device_mem_pool_base); + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + void set_device_mem_pool_base(uint8_t *device_mem_pool_base); void set_device_mem_pool_size(uint64_t device_mem_pool_size) { device_mem_pool_size_ = device_mem_pool_size; free_mem_size_ = device_mem_pool_size_; @@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { size_t free_mem_size() override; size_t total_mem_size() override; - static AscendMemoryPool& GetInstance() { + static AscendMemoryPool &GetInstance() { static AscendMemoryPool instance; return instance; } @@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { private: AscendMemoryPool() = default; bool has_malloc_{false}; - uint8_t* device_mem_pool_base_{nullptr}; + uint8_t *device_mem_pool_base_{nullptr}; uint64_t device_mem_pool_size_{0}; size_t free_mem_size_{0}; size_t total_mem_size_{0}; diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h index f7804a8ee77..9f4ea4d6676 100755 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h @@ -39,13 
+39,13 @@ using std::vector; class AscendStreamAssign { public: - static AscendStreamAssign& GetInstance() { + static AscendStreamAssign &GetInstance() { static AscendStreamAssign instance; // Guaranteed to be destroyed. return instance; } - AscendStreamAssign(const AscendStreamAssign&) = delete; - AscendStreamAssign& operator=(const AscendStreamAssign&) = delete; + AscendStreamAssign(const AscendStreamAssign &) = delete; + AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; uint32_t GetTotalStreamNum() const; // new stream policy @@ -53,19 +53,19 @@ class AscendStreamAssign { uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } uint32_t total_event_num() const { return total_event_num_; } - void InsertActiveNew(const std::shared_ptr& graph_ptr); - void AssignAllNodesStream(const std::shared_ptr& graph_ptr); + void InsertActiveNew(const std::shared_ptr &graph_ptr); + void AssignAllNodesStream(const std::shared_ptr &graph_ptr); void ResetNew(); - void AssignStreamNew(const std::shared_ptr& graph_ptr); - bool IsIndependentNode(const CNodePtr& node_ptr); - const std::unordered_map& logic_to_independent_map() { return logic_to_independent_map_; } - const std::unordered_map& logic_to_physic_map() { return logic_to_physic_map_; } - const std::vector>& inner_parallel_streams() { return inner_parallel_streams_; } - void GetWaitStreams(vector* wait_active_stream_list); - const std::vector& hcom_streams() { return hcom_stream_list_; } - CNodePtr CreateSendApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, + void AssignStreamNew(const std::shared_ptr &graph_ptr); + bool IsIndependentNode(const CNodePtr &node_ptr); + const std::unordered_map &logic_to_independent_map() { return logic_to_independent_map_; } + const std::unordered_map &logic_to_physic_map() { return logic_to_physic_map_; } + const std::vector> &inner_parallel_streams() { return inner_parallel_streams_; } + void GetWaitStreams(vector *wait_active_stream_list); + const std::vector &hcom_streams() { return hcom_stream_list_; } + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id); - CNodePtr CreateRecvApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id); private: @@ -73,30 +73,30 @@ class AscendStreamAssign { ~AscendStreamAssign() = default; vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr& node); + const CNodePtr &node); - bool IsHcom(const CNodePtr& apply_kernel); + bool IsHcom(const CNodePtr &apply_kernel); bool IsProcessed(uint32_t logic_id); - void TransLogicToPhysic(const vector& logic_ids, vector* physic_ids); - void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, - uint32_t* cur_stream_id); + void TransLogicToPhysic(const vector &logic_ids, vector *physic_ids); + void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, + uint32_t *cur_stream_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id); - void UpdateStreamActive(const CNodePtr& active_ptr); - void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); + void UpdateStreamActive(const CNodePtr &active_ptr); + void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr); bool IsTaskSink(); - void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t 
deal_logic_id); - void UpdateStreamId(const std::shared_ptr& graph_ptr); - void UpdateEventId(const std::shared_ptr& graph_ptr); - void PrintGraphExeOrders(const std::shared_ptr& graph_ptr); - void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); - uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); + void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id); + void UpdateStreamId(const std::shared_ptr &graph_ptr); + void UpdateEventId(const std::shared_ptr &graph_ptr); + void PrintGraphExeOrders(const std::shared_ptr &graph_ptr); + void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); + uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr); void SetCommonStreamNum(uint32_t cur_stream_id); - void FindAllReduceParallel(const std::shared_ptr& graph_ptr); + void FindAllReduceParallel(const std::shared_ptr &graph_ptr); bool IsProcessedParallelStream(uint32_t stream_id); - void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector* parallel_streams); - void InsertSendRecvForIndependent(const std::shared_ptr& graph_ptr); - void InsertSendRecvForHcomParallel(const std::shared_ptr& graph_ptr); - void GetNeedActiveStreams(const std::shared_ptr& graph_ptr); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); + void InsertSendRecvForIndependent(const std::shared_ptr &graph_ptr); + void InsertSendRecvForHcomParallel(const std::shared_ptr &graph_ptr); + void GetNeedActiveStreams(const std::shared_ptr &graph_ptr); uint32_t total_common_stream_num_{0}; uint32_t total_independ_stream_num_{0}; diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h index 668b54b78cd..bf4977bf9ae 100644 --- a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h +++ b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h @@ -28,14 +28,14 @@ namespace device { namespace ascend { class PluginImpl : public PluginIntf { public: - explicit PluginImpl(const std::string& module); + explicit PluginImpl(const std::string &module); ~PluginImpl() override = default; - int Init(const Reporter* reporter) override; + int Init(const Reporter *reporter) override; int UnInit() override; - static Reporter* GetPluginReporter() { return reporter_; } + static Reporter *GetPluginReporter() { return reporter_; } private: - static Reporter* reporter_; + static Reporter *reporter_; std::string module_; }; } // namespace ascend diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc index 3a1dc4689bb..cbecb3030d5 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc @@ -20,12 +20,12 @@ namespace mindspore { namespace device { namespace ascend { -PluginIntf* ProfilingEngineImpl::CreatePlugin() { +PluginIntf *ProfilingEngineImpl::CreatePlugin() { MS_LOG(INFO) << "Create Plugin."; return new (std::nothrow) PluginImpl("Framework"); } -int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { +int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { if (plugin != nullptr) { delete plugin; } diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h index e8dbfc7087c..c7cbc4b7dd4 100644 --- 
a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h @@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf { ProfilingEngineImpl() = default; ~ProfilingEngineImpl() override = default; - PluginIntf* CreatePlugin() override; - int ReleasePlugin(PluginIntf* plugin) override; + PluginIntf *CreatePlugin() override; + int ReleasePlugin(PluginIntf *plugin) override; }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc index 29193e5cfa9..c3f622ffee2 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc @@ -35,7 +35,7 @@ using Json = nlohmann::json; namespace mindspore { namespace device { namespace ascend { -ProfilingManager& ProfilingManager::GetInstance() { +ProfilingManager &ProfilingManager::GetInstance() { static ProfilingManager inst; return inst; } @@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { } uint64_t ProfilingManager::GetJobId() const { - const char* job_id = std::getenv("JOB_ID"); + const char *job_id = std::getenv("JOB_ID"); return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); } -bool ProfilingManager::ReportProfilingData(const map& op_taskId_map) const { +bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { if (!IsProfiling()) { MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; return false; @@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map& op_taskI MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); Msprof::Engine::ReporterData reporter_data = {}; - for (const auto& iter : op_taskId_map) { + for (const auto &iter : op_taskId_map) { auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; reporter_data.deviceId = UintToInt(device_id_); - reporter_data.data = (unsigned char*)(const_cast(data.c_str())); + reporter_data.data = (unsigned char *)(const_cast(data.c_str())); reporter_data.dataLen = data.size(); auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); if (ret != 0) { @@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map& op_taskI return true; } -static std::vector Split(const std::string& str, const char delim) { +static std::vector Split(const std::string &str, const char delim) { std::vector elems; if (str.empty()) { @@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { device_id_ = device_id; // exp: export PROFILING_MODE=true // export PROFILING_OPTIONS=training_trace - const char* prof_options_str = std::getenv("PROFILING_OPTIONS"); + const char *prof_options_str = std::getenv("PROFILING_OPTIONS"); // register Framework to profiling int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); if (result != 0) { @@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const { MS_LOG(INFO) << "No need profiling. 
please export PROFILING_MODE and in train mode."; return true; } - Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); + Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); if (reporter != nullptr) { MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); } diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.h b/mindspore/ccsrc/device/gpu/blocking_queue.h index ccf481858f7..a1594c21a97 100644 --- a/mindspore/ccsrc/device/gpu/blocking_queue.h +++ b/mindspore/ccsrc/device/gpu/blocking_queue.h @@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, class GpuQueue { public: - GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); + GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity); virtual ~GpuQueue(); - void RegisterRelease(const std::function& func) { host_release_ = func; } + void RegisterRelease(const std::function &func) { host_release_ = func; } inline bool IsEmpty() const { return head_ == tail_; } inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } - BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); - BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; + BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size); + BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const; BlockQueueStatus_T Pop(); bool Destroy(); private: struct NodeInfo { std::unique_ptr event_; - void* host_feature_addr_; - void* host_label_addr_; + void *host_feature_addr_; + void *host_label_addr_; }; - void* buffer_; + void *buffer_; size_t head_; size_t tail_; size_t feature_size_; @@ -61,10 +61,10 @@ class GpuQueue { size_t capacity_; cudaStream_t stream_; std::unique_ptr node_info_; - std::function host_release_; + std::function host_release_; - GpuQueue(const GpuQueue&) = delete; - GpuQueue& operator=(const GpuQueue&) = delete; + GpuQueue(const GpuQueue &) = delete; + GpuQueue &operator=(const GpuQueue &) = delete; }; class BlockingQueue { @@ -72,11 +72,11 @@ class BlockingQueue { BlockingQueue() : queue_(nullptr) {} ~BlockingQueue() = default; - BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); - void RegisterRelease(const std::function& func); - BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, + BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity); + void RegisterRelease(const std::function &func); + BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size, unsigned int timeout_in_sec); - BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); + BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size); BlockQueueStatus_T Pop(); bool Destroy(); diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc index d212c56ae73..d7ab95bbe84 100644 --- a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc +++ b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc @@ -20,17 +20,17 @@ namespace mindspore { namespace device { namespace gpu { -CollectiveInitializer& 
CollectiveInitializer::instance() { +CollectiveInitializer &CollectiveInitializer::instance() { static CollectiveInitializer instance = {}; return instance; } bool CollectiveInitializer::collective_inited() const { return collective_inited_; } -const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } +const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } void CollectiveInitializer::InitCollective() { - void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); + void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); if (handle == nullptr) { MS_LOG(EXCEPTION) << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc index b25ba2906b0..e505fdc218c 100644 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc +++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc @@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() { CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); } -bool GPUDeviceManager::CreateStream(DeviceStream* stream) { +bool GPUDeviceManager::CreateStream(DeviceStream *stream) { CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); gpu_streams_.emplace_back(*stream); return true; } -const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } +const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } @@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } -const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } +const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } -const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } +const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } -bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } +bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } -bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { +bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { return CudaDriver::CopyDeviceMemToHost(dst, src, size); } -bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { +bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { return CudaDriver::CopyHostMemToDevice(dst, src, size); } } // namespace gpu diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/device/gpu/gpu_device_manager.h index 3b3d2aecb50..a546b999a43 100644 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.h +++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.h @@ -37,17 +37,17 @@ class GPUDeviceManager { uint32_t cur_device_id() const; bool is_device_id_init() const; - bool CreateStream(DeviceStream* stream); - bool SyncStream(const DeviceStream& stream) const; - const DeviceStream& default_stream() const; + bool 
CreateStream(DeviceStream *stream); + bool SyncStream(const DeviceStream &stream) const; + const DeviceStream &default_stream() const; - const cudnnHandle_t& GetCudnnHandle() const; - const cublasHandle_t& GetCublasHandle() const; + const cudnnHandle_t &GetCudnnHandle() const; + const cublasHandle_t &GetCublasHandle() const; - bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; - bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; + bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; + bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; - static GPUDeviceManager& GetInstance() { + static GPUDeviceManager &GetInstance() { static GPUDeviceManager instance; return instance; } @@ -55,8 +55,8 @@ class GPUDeviceManager { private: GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} ~GPUDeviceManager() = default; - GPUDeviceManager(const GPUDeviceManager&) = delete; - GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; + GPUDeviceManager(const GPUDeviceManager &) = delete; + GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; // default CUDA stream used for all the kernels. DeviceStream default_stream_{nullptr}; diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc index cbd43645abf..3a1a53c6009 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc +++ b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc @@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() { return true; } -bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { +bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { auto alloc_size = AllocDeviceMem(size, addr); buffer_q_addr_ = *addr; // Buffer queue needs to ensure that the alloc_size and size is equal. return (alloc_size == size) ? 
true : false; } -size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { +size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (size == 0) { MS_LOG(EXCEPTION) << "The memory alloc size is 0."; } @@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { return alloc_size; } -bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } +bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h index 0d2f0f8a39a..36374bfaad9 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h +++ b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h @@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { ~GPUMemoryAllocator() override = default; bool Init(); bool Finalize(); - bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); + bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); - size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; - bool FreeDeviceMem(const DeviceMemPtr& addr) override; + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; size_t free_mem_size() override; size_t total_mem_size() override; - static GPUMemoryAllocator& GetInstance() { + static GPUMemoryAllocator &GetInstance() { static GPUMemoryAllocator instance; return instance; } private: GPUMemoryAllocator() = default; - GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; - GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; + GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; + GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; // Used to track address of data buffer queue. 
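For orientation, a minimal standalone sketch (hypothetical PoolExample class, not part of this patch) of the singleton-plus-deleted-copies pattern that the memory-pool hunks above reformat, written in the new PointerAlignment: Right style where '&' and '*' bind to the declared name:

#include <cstddef>
#include <cstdint>

class PoolExample {
 public:
  // Meyers singleton: constructed on first use, copy and assignment deleted.
  static PoolExample &GetInstance() {
    static PoolExample instance;
    return instance;
  }
  PoolExample(const PoolExample &) = delete;
  PoolExample &operator=(const PoolExample &) = delete;

  // Right-aligned '*' on parameters and members.
  size_t AllocDeviceMem(size_t size, uint8_t **addr) {
    *addr = nullptr;  // placeholder only; a real pool would carve from base_
    return size;
  }
  bool FreeDeviceMem(const uint8_t *addr) { return addr != nullptr; }

 private:
  PoolExample() = default;
  uint8_t *base_{nullptr};
  size_t free_size_{0};
};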
DeviceMemPtr buffer_q_addr_{nullptr}; diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc index 05ecf380d11..6ccb4c8cde1 100644 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc @@ -33,8 +33,8 @@ namespace gpu { using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; using mindspore::kernel::KernelBuildInfo; namespace { -bool CheckKernelInfo(const std::shared_ptr& alternative_kernel_info, - const std::shared_ptr& selected_kernel_info) { +bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, + const std::shared_ptr &selected_kernel_info) { MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(alternative_kernel_info); size_t selected_input_num = selected_kernel_info->GetInputNum(); @@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr& alternative_kernel_ return true; } -std::string SupportedTypeList(const CNodePtr& kernel_node) { +std::string SupportedTypeList(const CNodePtr &kernel_node) { std::string supported_type_lists = kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); if (!supported_type_lists.empty()) { @@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) { return supported_type_lists; } -bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr& selected_kernel_info) { +bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(selected_kernel_info); std::vector> kernel_info_list; @@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr& alternative_kernel_info) { + [&](const std::shared_ptr &alternative_kernel_info) { return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); }); if (!match) { @@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptrinput(input_index + 1); @@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co } } // namespace -void SetKernelInfo(const CNodePtr& kernel_node) { +void SetKernelInfo(const CNodePtr &kernel_node) { std::vector inputs_format; std::vector inputs_type; std::shared_ptr builder = diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/device/gpu/kernel_info_setter.h index e3dc2241a9a..b351f74fa33 100644 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.h @@ -27,7 +27,7 @@ namespace mindspore { namespace device { namespace gpu { -void SetKernelInfo(const CNodePtr& apply_kernel_ptr); +void SetKernelInfo(const CNodePtr &apply_kernel_ptr); class KernelAttr { public: @@ -35,24 +35,24 @@ class KernelAttr { KernelAttr() : all_same_(false) {} ~KernelAttr() = default; - KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { + KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { input_type_.emplace_back(ms_type, format); return *this; } - KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { + KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { output_type_.emplace_back(ms_type, format); return *this; } - KernelAttr& AddAllSameAttr(const bool& all_same) { + KernelAttr &AddAllSameAttr(const bool &all_same) { all_same_ = all_same; return *this; } - 
const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } - const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } - const bool& GetAllSame() const { return all_same_; } + const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } + const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } + const bool &GetAllSame() const { return all_same_; } size_t GetInputSize() const { return input_type_.size(); } size_t GetOutputSize() const { return output_type_.size(); } diff --git a/mindspore/ccsrc/gvar/typeid_manager.cc b/mindspore/ccsrc/gvar/typeid_manager.cc index 97250a65710..f40052411ab 100644 --- a/mindspore/ccsrc/gvar/typeid_manager.cc +++ b/mindspore/ccsrc/gvar/typeid_manager.cc @@ -24,7 +24,7 @@ namespace mindspore { -struct TypeIdManager* TypeIdManager::Get() { +struct TypeIdManager *TypeIdManager::Get() { static TypeIdManager manager; return &manager; } diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index 658fb578b78..dd86e467134 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } std::string AnfNode::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } -CNode::CNode(const std::vector& inputs, const FuncGraphPtr& func_graph) +CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} // Check if CNode is an apply with the specific Primitive. 
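As a side note on the KernelAttr hunks above: each Add*Attr setter returns a reference to *this, which is what makes fluent chaining work. A minimal hedged sketch (hypothetical AttrExample, not taken from this patch) in the same right-aligned style:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

class AttrExample {
 public:
  // Returning '*this' by reference lets callers chain Add calls.
  AttrExample &AddInput(const std::string &type, const std::string &format) {
    inputs_.emplace_back(type, format);
    return *this;
  }
  AttrExample &AddOutput(const std::string &type, const std::string &format) {
    outputs_.emplace_back(type, format);
    return *this;
  }
  std::size_t InputSize() const { return inputs_.size(); }

 private:
  std::vector<std::pair<std::string, std::string>> inputs_;
  std::vector<std::pair<std::string, std::string>> outputs_;
};

// Usage: AttrExample().AddInput("float32", "NCHW").AddOutput("float32", "NCHW");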
-bool CNode::IsApply(const PrimitivePtr& value) const { +bool CNode::IsApply(const PrimitivePtr &value) const { if (value == nullptr) { return false; } @@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const { return false; } -void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } +void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } std::string CNode::DebugString(int recursive_level) const { std::ostringstream buffer; @@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const { buffer << ToString() << "{"; bool is_first_node = true; int idx = 0; - for (auto& node : inputs_) { + for (auto &node : inputs_) { MS_EXCEPTION_IF_NULL(node); if (is_first_node) { is_first_node = false; @@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const { return buffer.str(); } -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { +OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { if (operator_info_ != nullptr) { MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() << ", using the new one: " << operator_info->name(); @@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() { return fullname_with_scope_; } -void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } -void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } -void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } +void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { +bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if (cnode != nullptr) { @@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { return false; } -PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { +PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { if (node == nullptr) { return nullptr; } @@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { return ""; } -bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { +bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { if (IsValueNode(node)) { PrimitivePtr fn_value = GetValueNode(node); MS_EXCEPTION_IF_NULL(value); @@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { } namespace id_generator { static std::unordered_map node_ids; -std::string get_id(const AnfNodePtr& node) { +std::string get_id(const AnfNodePtr &node) { auto type_name = node->type_name(); if (node_ids.find(type_name) == node_ids.end()) { node_ids[type_name] = 0; diff --git a/mindspore/ccsrc/ir/base.h b/mindspore/ccsrc/ir/base.h index 6a3537306f6..7ccef138766 100644 --- a/mindspore/ccsrc/ir/base.h +++ b/mindspore/ccsrc/ir/base.h @@ -39,15 +39,15 @@ struct is_shared_ptr> : public std::true_type {}; class Base : public std::enable_shared_from_this { public: constexpr Base() = default; - Base(const Base& other) : std::enable_shared_from_this(other) {} - virtual bool operator==(const Base& rhs) { + Base(const Base &other) : std::enable_shared_from_this(other) {} + virtual bool operator==(const Base &rhs) { if (this == &rhs) { return true; } 
return false; } - virtual Base& operator=(const Base&) { return *this; } + virtual Base &operator=(const Base &) { return *this; } virtual ~Base() = default; virtual std::size_t hash() const { return tid(); } virtual std::string ToString() const { return type_name(); } @@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this { virtual const bool IsFromTypeId(uint32_t tid) const; virtual std::string type_name() const { return "Base"; } - static uint32_t GetTypeId(const char* const type_key); + static uint32_t GetTypeId(const char *const type_key); virtual uint32_t tid() const { static const uint32_t tid = GetTypeId(typeid(Base).name()); return tid; } template ::value && std::is_base_of::value, T>::type* = nullptr> + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> inline bool isa() const { static const uint32_t tid = GetTypeId(typeid(T).name()); return this->IsFromTypeId(tid); @@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr; using BaseWeakPtr = std::weak_ptr; template -inline T* cast(U* source) { +inline T *cast(U *source) { if (source != nullptr && source->template isa()) { - return static_cast(source); + return static_cast(source); } else { return nullptr; } @@ -100,7 +100,7 @@ inline T* cast(U* source) { template < typename T, typename U, - typename std::enable_if::value && std::is_base_of::value, T>::type* = nullptr> + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> inline std::shared_ptr dyn_cast(const std::shared_ptr r) { if (r != nullptr && r->template isa()) { return std::static_pointer_cast(r); @@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager { std::mutex mutex; std::atomic type_counter{0}; std::unordered_map map; - static TypeIdManager* Get(); + static TypeIdManager *Get(); TypeIdManager() : mutex(), type_counter(0), map() {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc index 65a42bc3fa6..a6ef99177c4 100644 --- a/mindspore/ccsrc/ir/dtype.cc +++ b/mindspore/ccsrc/ir/dtype.cc @@ -48,11 +48,11 @@ std::string Keyword::ToString() const { return buffer.str(); } -bool Keyword::operator==(const Type& other) const { +bool Keyword::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_keyword = static_cast(other); + const auto &other_keyword = static_cast(other); return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); } @@ -87,11 +87,11 @@ std::string Slice::ToString() const { return buffer.str(); } -bool Slice::operator==(const Type& other) const { +bool Slice::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_slice = static_cast(other); + auto other_slice = static_cast(other); return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); } @@ -122,11 +122,11 @@ std::string TensorType::DumpText() const { } } -bool TensorType::operator==(const Type& other) const { +bool TensorType::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_elem_type = static_cast(other).element_type_; + auto other_elem_type = static_cast(other).element_type_; // When element_type_ = nullptr, which means any type of Array. 
if (element_type_ == nullptr && other_elem_type == nullptr) { return true; @@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) { retval_ = nullptr; } -Function::Function(const std::vector& args, const TypePtr retval) +Function::Function(const std::vector &args, const TypePtr retval) : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} TypePtr Function::DeepCopy() const { @@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const { TypePtrList args; TypePtr retval = nullptr; (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), - [](const TypePtr& arg) { return arg->DeepCopy(); }); + [](const TypePtr &arg) { return arg->DeepCopy(); }); if (retval_ != nullptr) { retval = retval_->DeepCopy(); } @@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const { } } -bool Function::operator==(const Type& other) const { +bool Function::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_function = static_cast(other); + const auto &other_function = static_cast(other); if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { if (*retval_ != *other_function.retval_) { return false; @@ -188,7 +188,7 @@ std::string Function::ToString() const { } else { buffer << "Func[("; bool begin = true; - for (auto& attr : args_) { + for (auto &attr : args_) { if (!begin) { buffer << ", "; } else { @@ -242,34 +242,34 @@ std::string JTagged::DumpText() const { return buffer.str(); } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr problem) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem) { MS_EXCEPTION_IF_NULL(problem); os << problem->ToString(); return os; } -std::size_t TypeHasher::operator()(TypePtr const& type) const { +std::size_t TypeHasher::operator()(TypePtr const &type) const { MS_EXCEPTION_IF_NULL(type); std::size_t hash = std::hash()(type->type_id()); return hash; } -std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { +std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { std::size_t hash_sum = 0; - for (auto& type : type_list) { + for (auto &type : type_list) { auto type_id = static_cast(type->type_id()); hash_sum = hash_combine(hash_sum, type_id); } return hash_sum; } -bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { +bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); return t1->type_id() == t2->type_id(); } -bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { +bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { if (lhs.size() != rhs.size()) { return false; } @@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) { namespace { template -TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { +TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { TypePtr type = nullptr; if (type_name == num_type_name) { type = std::make_shared(); @@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_ } auto bits = std::stoi(type_name.substr(num_type_name.size())); type = std::make_shared(bits); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); } } return type; } -std::vector StringToVectorOfType(const std::string& 
type_names) { +std::vector StringToVectorOfType(const std::string &type_names) { std::vector types; if (type_names.length() == 0) { return types; @@ -371,7 +371,7 @@ std::vector StringToVectorOfType(const std::string& type_names) { return types; } -TypePtr TensorStrToType(const std::string& type_name) { +TypePtr TensorStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Tensor") { type = std::make_shared(); @@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) { return nullptr; } type = std::make_shared(element_type); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) { return type; } -TypePtr ListStrToType(const std::string& type_name) { +TypePtr ListStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "List") { type = std::make_shared(); @@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) { std::string element_strs = type_name.substr(start, end - start); std::vector element_types = StringToVectorOfType(element_strs); bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); if (wrong) { return nullptr; } type = std::make_shared(element_types); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) { return type; } -TypePtr TupleStrToType(const std::string& type_name) { +TypePtr TupleStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Tuple") { type = std::make_shared(); @@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) { std::string element_strs = type_name.substr(start, end - start); std::vector element_types = StringToVectorOfType(element_strs); bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); if (wrong) { return nullptr; } type = std::make_shared(element_types); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } return type; } -TypePtr FunctionStrToType(const std::string& type_name) { +TypePtr FunctionStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Function") { @@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) { std::vector args_type = StringToVectorOfType(str_args); TypePtr retval = StringToType(str_retval); - bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); + bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); if (retval == nullptr || wrong) { return nullptr; } type = std::make_shared(args_type, retval); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) { } 
} // namespace -TypePtr StringToType(const std::string& type_name) { +TypePtr StringToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name.compare("None") == 0) { type = std::make_shared(); @@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) { return type; } -bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { if (x == nullptr || base_type == nullptr) { MS_LOG(ERROR) << "Type is nullptr."; return false; @@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { } } -bool IsSubType(TypePtr const& t1, TypePtr const& t2) { +bool IsSubType(TypePtr const &t1, TypePtr const &t2) { MS_EXCEPTION_IF_NULL(t1); if (t1->type_id() == kTypeUnknown) { return false; @@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) { } REGISTER_PYBIND_DEFINE( - typing, ([](py::module* const m) { + typing, ([](py::module *const m) { auto m_sub = m->def_submodule("typing", "submodule for dtype"); py::enum_(m_sub, "TypeId"); (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); (void)m_sub.def("load_type", &TypeIdToType, "load type"); (void)m_sub.def( - "dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); + "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); (void)py::class_>(m_sub, "Type") .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) .def("__eq__", - [](const TypePtr& t1, const TypePtr& t2) { + [](const TypePtr &t1, const TypePtr &t2) { if (t1 != nullptr && t2 != nullptr) { return *t1 == *t2; } @@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE( .def("__hash__", &Type::hash) .def("__str__", &Type::ToString) .def("__repr__", &Type::ReprString) - .def("__deepcopy__", [](const TypePtr& t, py::dict) { + .def("__deepcopy__", [](const TypePtr &t, py::dict) { if (t == nullptr) { return static_cast(nullptr); } @@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "Bool") .def(py::init()) .def(py::pickle( - [](const Bool&) { // __getstate__ + [](const Bool &) { // __getstate__ return py::make_tuple(); }, - [](const py::tuple&) { // __setstate__ + [](const py::tuple &) { // __setstate__ return std::make_shared(); })); (void)py::class_>(m_sub, "Int") .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const Int& t) { // __getstate__ + [](const Int &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const UInt& t) { // __getstate__ + [](const UInt &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const Float& t) { // __getstate__ + [](const Float &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + 
[](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init(), py::arg("element")) .def("element_type", &TensorType::element) .def(py::pickle( - [](const TensorType& t) { // __getstate__ + [](const TensorType &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } diff --git a/mindspore/ccsrc/ir/dtype.h b/mindspore/ccsrc/ir/dtype.h index e3e2099b5ef..cefdf420994 100644 --- a/mindspore/ccsrc/ir/dtype.h +++ b/mindspore/ccsrc/ir/dtype.h @@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr; class Keyword : public Object { public: Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} - Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} + Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} ~Keyword() override = default; MS_DECLARE_PARENT(Keyword, Object) @@ -70,7 +70,7 @@ class Keyword : public Object { std::string ToString() const override; std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::string GetKey() const { return key_; } TypePtr GetValue() const { return value_; } @@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr; class Slice : public Object { public: Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} - Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) + Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} ~Slice() override = default; @@ -95,7 +95,7 @@ class Slice : public Object { std::string ToString() const override; std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr get_start() const { return start_; } TypePtr get_stop() const { return stop_; } @@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr; class TensorType : public Object { public: TensorType() : Object(kObjectTypeTensorType) {} - explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} + explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} ~TensorType() override = default; MS_DECLARE_PARENT(TensorType, Object) TypeId generic_type_id() const override { return kObjectTypeTensorType; } const TypePtr element() const { return element_type_; } - void set_element(const TypePtr& element_type) { element_type_ = element_type; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } TypePtr DeepCopy() const override; std::string ToString() const override; std::string ToReprString() const override { return "tensor"; } std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; private: TypePtr element_type_; @@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr; class Function : public Object { public: Function(); - Function(const std::vector& args, const TypePtr 
retval); + Function(const std::vector &args, const TypePtr retval); ~Function() override = default; MS_DECLARE_PARENT(Function, Object) @@ -141,11 +141,11 @@ class Function : public Object { // Add temporarily for return abstraction to avoid type checking. bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } - const std::vector& args() const { return args_; } - const TypePtr& retval() const { return retval_; } + const std::vector &args() const { return args_; } + const TypePtr &retval() const { return retval_; } TypePtr DeepCopy() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::string ToString() const override; std::string ToReprString() const override { return "function"; } @@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr; class JTagged : public Object { public: JTagged() : Object(kObjectTypeJTagged) {} - explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} + explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} ~JTagged() override = default; MS_DECLARE_PARENT(JTagged, Object) @@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr; class Problem : public Type { public: Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} - explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} + explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} ~Problem() override = default; MS_DECLARE_PARENT(Problem, Type) @@ -222,7 +222,7 @@ class Problem : public Type { std::string ToString() const override { return kind_.name(); } std::string DumpText() const override { return "ProblemType"; } - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr problem); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem); private: Named kind_; @@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr; // helper template template -TypePtr Clone(const T& t) { +TypePtr Clone(const T &t) { return t.Clone(); } -TypePtr StringToType(const std::string& type_name); +TypePtr StringToType(const std::string &type_name); // Judge whether x is predicate or is a subclass of predicate. -bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); // Whether t1 is identity or a subclass of t2. 
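For context on the TypeHasher / TypeEqual functors this header declares: they exist so that shared_ptr keys are hashed and compared by the stored type id rather than by pointer identity. A minimal hedged sketch (hypothetical KeyExample types, not from this patch) of how such functors plug into an unordered container:

#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct KeyExample {
  int id;
};
using KeyExamplePtr = std::shared_ptr<KeyExample>;

// Hash and equality look through the shared_ptr at the stored id, so two
// distinct objects with the same id land in the same bucket and compare equal.
struct KeyExampleHasher {
  std::size_t operator()(KeyExamplePtr const &k) const { return std::hash<int>()(k->id); }
};
struct KeyExampleEqual {
  bool operator()(KeyExamplePtr const &a, KeyExamplePtr const &b) const { return a->id == b->id; }
};

using ExampleCache = std::unordered_map<KeyExamplePtr, std::string, KeyExampleHasher, KeyExampleEqual>;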
-bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); +bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); struct TypeHasher { - std::size_t operator()(TypePtr const& type) const; + std::size_t operator()(TypePtr const &type) const; }; struct TypeListHasher { - std::size_t operator()(const TypePtrList& type_list) const; + std::size_t operator()(const TypePtrList &type_list) const; }; struct TypeEqual { - bool operator()(TypePtr const& t1, TypePtr const& t2) const; + bool operator()(TypePtr const &t1, TypePtr const &t2) const; }; struct TypeListEqual { - bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; + bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; }; extern const TypePtr kTypeExternal; diff --git a/mindspore/ccsrc/ir/dtype/container.cc b/mindspore/ccsrc/ir/dtype/container.cc index 8bca29f7935..3f8244c2e38 100644 --- a/mindspore/ccsrc/ir/dtype/container.cc +++ b/mindspore/ccsrc/ir/dtype/container.cc @@ -24,7 +24,7 @@ #include "pybind_api/export_flags.h" namespace mindspore { -static std::string DumpTypeVector(const std::vector& elements, bool is_dumptext) { +static std::string DumpTypeVector(const std::vector &elements, bool is_dumptext) { std::ostringstream oss; bool begin = true; int cnt = 0; @@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const { } else { TypePtrList elements; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), - [](const TypePtr& ele) { return ele->DeepCopy(); }); + [](const TypePtr &ele) { return ele->DeepCopy(); }); auto copy = std::make_shared(elements); return copy; } @@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const { return elements_[dim]; } -bool List::operator==(const Type& other) const { +bool List::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const List& other_list = static_cast(other); + const List &other_list = static_cast(other); if (elements_.size() != other_list.elements_.size()) { return false; } @@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const { return true; } -Class::Class(const Named& tag, const ClassAttrVector& attributes, - const std::unordered_map& methods) +Class::Class(const Named &tag, const ClassAttrVector &attributes, + const std::unordered_map &methods) : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} std::string List::ToString() const { @@ -122,7 +122,7 @@ std::string List::DumpText() const { return buffer.str(); } -bool Class::operator==(const Type& other) const { +bool Class::operator==(const Type &other) const { // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. return &other == this; } @@ -143,7 +143,7 @@ std::string Class::ToString() const { } else { bool begin = true; buffer << "cls." << tag_ << "["; - for (auto& attr : attributes_) { + for (auto &attr : attributes_) { if (!begin) { buffer << ", "; } else { @@ -163,7 +163,7 @@ std::string Class::DumpText() const { } else { bool begin = true; buffer << "Cls." 
<< tag_ << "["; - for (auto& attr : attributes_) { + for (auto &attr : attributes_) { if (!begin) { buffer << ", "; } else { @@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const { } else { TypePtrList elements; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), - [](const TypePtr& ele) { return ele->DeepCopy(); }); + [](const TypePtr &ele) { return ele->DeepCopy(); }); auto copy = std::make_shared(elements); return copy; } } -bool Tuple::operator==(const Type& other) const { +bool Tuple::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_tuple = static_cast(other); + auto other_tuple = static_cast(other); if (elements_.size() != other_tuple.elements_.size()) { return false; } @@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const { std::vector> kv; (void)std::transform( key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const std::pair& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); + [](const std::pair &item) { return std::make_pair(item.first, item.second->DeepCopy()); }); return std::make_shared(kv); } } @@ -259,7 +259,7 @@ std::string Dictionary::ToString() const { std::ostringstream buffer; std::vector keys; std::vector values; - for (const auto& kv : key_values_) { + for (const auto &kv : key_values_) { keys.push_back(kv.first); values.push_back(kv.second); } @@ -276,12 +276,12 @@ std::string Dictionary::ToString() const { std::string Dictionary::DumpText() const { return ToString(); } -bool Dictionary::operator==(const mindspore::Type& other) const { +bool Dictionary::operator==(const mindspore::Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_dict = static_cast(other); + const auto &other_dict = static_cast(other); if (key_values_.size() != other_dict.key_values_.size()) { return false; } diff --git a/mindspore/ccsrc/ir/dtype/container.h b/mindspore/ccsrc/ir/dtype/container.h index 04ed484cf7e..0612d24c4dd 100644 --- a/mindspore/ccsrc/ir/dtype/container.h +++ b/mindspore/ccsrc/ir/dtype/container.h @@ -40,10 +40,10 @@ namespace mindspore { class List : public Object { public: List() : Object(kObjectTypeList) {} - List(const std::initializer_list& objs) + List(const std::initializer_list &objs) : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} // Shadow copy; - explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} + explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} ~List() override {} MS_DECLARE_PARENT(List, Object) @@ -51,7 +51,7 @@ class List : public Object { TypeId generic_type_id() const override { return kObjectTypeList; } TypePtr DeepCopy() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::size_t size() const { return elements_.size(); } TypePtrList elements() const { return elements_; } std::string ToString() const override; @@ -68,22 +68,22 @@ using ClassAttrVector = std::vector>; class Class : public Object { public: Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} - Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map& methods); + Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map &methods); ~Class() override {} MS_DECLARE_PARENT(Class, Object) TypeId generic_type_id() const override { return kObjectTypeClass; } - bool operator==(const Type& 
other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; - void set_value(const std::unordered_map& v) { attributes_value_ = v; } + void set_value(const std::unordered_map &v) { attributes_value_ = v; } Named tag() { return tag_; } std::unordered_map GetValue() { return attributes_value_; } std::unordered_map methods() { return methods_; } - ClassAttrVector& GetAttributes() { return attributes_; } + ClassAttrVector &GetAttributes() { return attributes_; } ClassAttrVector attributes_; @@ -99,11 +99,11 @@ class Tuple : public Object { public: Tuple() : Object(kObjectTypeTuple) {} // usage : Tuple t = {std::make_shared(), std::make_shared(32)}; - Tuple(const std::initializer_list& objs) + Tuple(const std::initializer_list &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} // Shadow copy - explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} + explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} ~Tuple() override {} MS_DECLARE_PARENT(Tuple, Object) @@ -115,7 +115,7 @@ class Tuple : public Object { std::string ToReprString() const override { return "tuple_"; } std::string DumpText() const override; const TypePtr operator[](size_t dim) const; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtrList elements() const { return elements_; } std::size_t size() const { return elements_.size(); } @@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr; class Dictionary : public Object { public: Dictionary() : Object(kObjectTypeDictionary) {} - explicit Dictionary(const std::vector>& key_values) + explicit Dictionary(const std::vector> &key_values) : Object(kObjectTypeDictionary, false), key_values_(key_values) {} ~Dictionary() override {} @@ -136,7 +136,7 @@ class Dictionary : public Object { TypeId generic_type_id() const override { return kObjectTypeDictionary; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; diff --git a/mindspore/ccsrc/ir/dtype/number.cc b/mindspore/ccsrc/ir/dtype/number.cc index d9ef6bb3bd1..44ac9e8e6aa 100644 --- a/mindspore/ccsrc/ir/dtype/number.cc +++ b/mindspore/ccsrc/ir/dtype/number.cc @@ -24,11 +24,11 @@ #include "pybind_api/export_flags.h" namespace mindspore { -bool Number::operator==(const Type& other) const { +bool Number::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_number = static_cast(other); + auto other_number = static_cast(other); return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); } diff --git a/mindspore/ccsrc/ir/dtype/number.h b/mindspore/ccsrc/ir/dtype/number.h index cb3b0a607c8..3930f51d730 100644 --- a/mindspore/ccsrc/ir/dtype/number.h +++ b/mindspore/ccsrc/ir/dtype/number.h @@ -49,12 +49,12 @@ class Number : public Object { TypeId type_id() const override { return number_type_; } TypeId generic_type_id() const override { return kObjectTypeNumber; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override { return std::make_shared(); } std::string ToString() const override { 
return "Number"; } std::string ToReprString() const override { return "number"; } std::string DumpText() const override { return "Number"; } - std::string GetTypeName(const std::string& type_name) const { + std::string GetTypeName(const std::string &type_name) const { std::ostringstream oss; oss << type_name; if (nbits() != 0) { diff --git a/mindspore/ccsrc/ir/dtype/ref.h b/mindspore/ccsrc/ir/dtype/ref.h index 7f1dc4a95fc..7d8159289f0 100644 --- a/mindspore/ccsrc/ir/dtype/ref.h +++ b/mindspore/ccsrc/ir/dtype/ref.h @@ -51,7 +51,7 @@ class RefKeyType : public Object { class RefType : public Object { public: RefType() : Object(kObjectTypeRef) {} - RefType(const TypePtr& subtype, const TypePtr& subtype_origin) + RefType(const TypePtr &subtype, const TypePtr &subtype_origin) : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} ~RefType() override {} MS_DECLARE_PARENT(RefType, Object) diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/ccsrc/ir/dtype/type.cc index 6fbd7f81110..30bf0c8e3fb 100644 --- a/mindspore/ccsrc/ir/dtype/type.cc +++ b/mindspore/ccsrc/ir/dtype/type.cc @@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) { } } -const char* MetaIdLabel(const TypeId& v) { +const char *MetaIdLabel(const TypeId &v) { switch (v) { case kTypeUnknown: return "kTypeUnknown"; @@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) { } } -const char* ObjectIdLabel(const TypeId& v) { +const char *ObjectIdLabel(const TypeId &v) { switch (v) { case kObjectTypeNumber: return "kObjectTypeNumber"; @@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) { } } -const char* NumberIdLabel(const TypeId& v) { +const char *NumberIdLabel(const TypeId &v) { switch (v) { case kNumberTypeBool: return "kNumberTypeBool"; @@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) { } } -const char* TypeIdLabel(const TypeId& v) { +const char *TypeIdLabel(const TypeId &v) { if (v < kMetaTypeEnd) { return MetaIdLabel(v); } else { @@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) { } } -bool IsSameObjectType(const Type& lhs, const Type& rhs) { +bool IsSameObjectType(const Type &lhs, const Type &rhs) { if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { return false; } return lhs.object_type() == rhs.object_type(); } -size_t GetTypeByte(const TypePtr& type_ptr) { +size_t GetTypeByte(const TypePtr &type_ptr) { if (type_ptr && type_ptr->isa()) { auto number = dyn_cast(type_ptr); if (!number) { @@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) { } } -bool Type::operator==(const Value& other) const { +bool Type::operator==(const Value &other) const { if (other.isa()) { - auto other_type = static_cast(&other); + auto other_type = static_cast(&other); return *this == *other_type; } else { return false; @@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() { return ptr; } -std::ostream& operator<<(std::ostream& os, const Type& type) { +std::ostream &operator<<(std::ostream &os, const Type &type) { os << type.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const TypePtr type) { +std::ostream &operator<<(std::ostream &os, const TypePtr type) { os << type->ToString(); return os; } @@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const { return false; } -std::ostream& operator<<(std::ostream& os, const Object& obj) { +std::ostream &operator<<(std::ostream &os, const Object &obj) { os << obj.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const 
std::shared_ptr obj) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj) { os << obj->ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { +std::ostream &operator<<(std::ostream &os, const TypePtrList &types) { os << "["; for (size_t i = 0; i < types.size(); ++i) { if (i > 0) { diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/ccsrc/ir/dtype/type.h index 9454596538b..0528bccf03a 100644 --- a/mindspore/ccsrc/ir/dtype/type.h +++ b/mindspore/ccsrc/ir/dtype/type.h @@ -95,10 +95,10 @@ enum TypeId : int { TypeId IntBitsToTypeId(const int nbits); TypeId UIntBitsToTypeId(const int nbits); TypeId FloatBitsToTypeId(const int nbits); -const char* TypeIdLabel(const TypeId& v); +const char *TypeIdLabel(const TypeId &v); TypeId NormalizeTypeId(const TypeId type_id); -bool IsSameObjectType(const Type& lhs, const Type& rhs); -size_t GetTypeByte(const TypePtr& type_ptr); +bool IsSameObjectType(const Type &lhs, const Type &rhs); +size_t GetTypeByte(const TypePtr &type_ptr); // Base class for all types // forward declaration. @@ -110,14 +110,14 @@ class Type : public Value { ~Type() override = default; MS_DECLARE_PARENT(Type, Value) - bool operator==(const Value& other) const override; + bool operator==(const Value &other) const override; TypeId meta_type() const { return meta_type_; } virtual TypeId type_id() const { return meta_type_; } virtual TypeId generic_type_id() const { return kMetaTypeType; } - virtual bool operator!=(const Type& other) const { return !(*this == other); } - virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } + virtual bool operator!=(const Type &other) const { return !(*this == other); } + virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } virtual bool equal(const TypePtr other) const { return *this == *other; } virtual TypeId object_type() const { return kTypeUnknown; } @@ -134,8 +134,8 @@ class Type : public Value { bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } bool IsGeneric() const { return is_generic_; } abstract::AbstractBasePtr ToAbstract() override; - friend std::ostream& operator<<(std::ostream& os, const Type& type); - friend std::ostream& operator<<(std::ostream& os, const TypePtr type); + friend std::ostream &operator<<(std::ostream &os, const Type &type); + friend std::ostream &operator<<(std::ostream &os, const TypePtr type); const bool parse_info_ = true; @@ -163,14 +163,14 @@ class Object : public Type { bool equal(const TypePtr other) const override; std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } - friend std::ostream& operator<<(std::ostream& os, const Object& obj); - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr obj); + friend std::ostream &operator<<(std::ostream &os, const Object &obj); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj); private: const TypeId object_type_; }; -std::ostream& operator<<(std::ostream& os, const TypePtrList& types); +std::ostream &operator<<(std::ostream &os, const TypePtrList &types); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 93fd9c09368..8a58f320f13 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -61,7 +61,7 @@ FuncGraph::FuncGraph() AbstractFunctionPtr FuncGraph::abstract() { AbstractBasePtrList 
args_spec_list; - for (auto& p : parameters_) { + for (auto &p : parameters_) { MS_EXCEPTION_IF_NULL(p); if (p->abstract() == nullptr) { MS_LOG(ERROR) << "Error!!"; @@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() { return std::make_shared(args_spec_list, output()->abstract()); } -abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) { +abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { AnalysisContextPtr temp_context = context; if (temp_context == nullptr) { temp_context = abstract::AnalysisContext::DummyContext(); @@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const { } } -void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) { +void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { if (force_new_ret || return_ == nullptr) { std::vector params({NewValueNode(prim::kPrimReturn), value}); FuncGraphPtr this_graph = shared_from_base(); @@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() { return p; } -void FuncGraph::add_parameter(const ParameterPtr& p) { +void FuncGraph::add_parameter(const ParameterPtr &p) { if (manager_.lock()) { std::vector new_params = parameters_; new_params.push_back(p); @@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) { } } -ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { +ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { FuncGraphPtr this_graph = shared_from_base(); ParameterPtr p = std::make_shared(this_graph); p->set_name(name); @@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { return p; } -bool FuncGraph::has_flag(const std::string& flag) { +bool FuncGraph::has_flag(const std::string &flag) { if (flags_.count(flag)) { return flags_[flag]; } return false; } -CNodePtr FuncGraph::NewCNode(const std::vector& inputs) { +CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { CNodePtr cnode = std::make_shared(inputs, shared_from_base()); if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { order_.push_back(cnode); @@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector& inputs) { return cnode; } -CNodePtr FuncGraph::NewCNodeWithScope(const std::vector& inputs, const ScopePtr& scope) { +CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { CNodePtr app = NewCNode(inputs); app->set_scope(scope); return app; @@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector& inputs, con void FuncGraph::DumpCNodeList() { MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; - for (const auto& cnode : order_) { + for (const auto &cnode : order_) { MS_LOG(INFO) << cnode->DebugString(); } } std::string FuncGraph::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } GraphDebugInfoPtr FuncGraph::debug_info() { @@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() { return this->debug_info_; } -const AnfNodeSet& FuncGraph::nodes() { +const AnfNodeSet &FuncGraph::nodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& nodes = mng->nodes(); + auto &nodes = mng->nodes(); return nodes[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::value_nodes() { +const AnfNodeCounterMap &FuncGraph::value_nodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - 
auto& cts = mng->valuenodes(); + auto &cts = mng->valuenodes(); return cts[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::free_variables_direct() { +const AnfNodeCounterMap &FuncGraph::free_variables_direct() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& fv_direct = mng->free_variables_direct(); + auto &fv_direct = mng->free_variables_direct(); return fv_direct[shared_from_base()]; } -const BaseRefCounterMap& FuncGraph::free_variables_total() { +const BaseRefCounterMap &FuncGraph::free_variables_total() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& fv_total = mng->free_variables_total(); + auto &fv_total = mng->free_variables_total(); return fv_total[shared_from_base()]; } std::vector FuncGraph::free_variables_nodes() { std::vector nodes; - const auto& fv_total = this->free_variables_total(); - for (auto& p : fv_total) { + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { auto key = p.first; if (utils::isa(key)) { nodes.push_back(utils::cast(key)); @@ -238,8 +238,8 @@ std::vector FuncGraph::free_variables_nodes() { std::vector FuncGraph::free_variables_func_graphs() { std::vector func_graphs; - const auto& fv_total = this->free_variables_total(); - for (auto& p : fv_total) { + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { auto key = p.first; if (utils::isa(key)) { func_graphs.push_back(utils::cast(key)); @@ -249,31 +249,31 @@ std::vector FuncGraph::free_variables_func_graphs() { return func_graphs; } -const FuncGraphCounterMap& FuncGraph::func_graphs_used() { +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& used = mng->func_graphs_used(); + auto &used = mng->func_graphs_used(); return used[shared_from_base()]; } -const FuncGraphSet& FuncGraph::func_graphs_used_total() { +const FuncGraphSet &FuncGraph::func_graphs_used_total() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& used = mng->func_graphs_used_total(shared_from_base()); + auto &used = mng->func_graphs_used_total(shared_from_base()); return used; } -const FuncGraphCounterMap& FuncGraph::func_graph_users() { +const FuncGraphCounterMap &FuncGraph::func_graph_users() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& users = mng->func_graph_users(); + auto &users = mng->func_graph_users(); return users[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() { +const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& users = mng->func_graph_user_cnodes(); + auto &users = mng->func_graph_user_cnodes(); return users[shared_from_base()]; } @@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() { return mng->parent(shared_from_base()); } -const FuncGraphSet& FuncGraph::children() { +const FuncGraphSet &FuncGraph::children() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); return mng->children(shared_from_base()); } -const FuncGraphSet& FuncGraph::scope() { +const FuncGraphSet &FuncGraph::scope() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); return mng->scopes(shared_from_base()); @@ -312,9 +312,9 @@ std::shared_ptr> FuncGraph::recursive_graphs() { return mng->recursive_graphs(shared_from_base()); } -void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base()); } +void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); 
} -AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { +AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { auto itr = this->parameter_default_value_.find(name); if (itr == parameter_default_value_.end()) { return nullptr; @@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { } // set the default values -void FuncGraph::SetDefaultValues(const std::vector& name_list, const std::vector& value_list) { +void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { auto all_is_null = std::all_of(value_list.begin(), value_list.end(), - [](const AnfNodePtr& node) { return IsValueNode(node); }); + [](const AnfNodePtr &node) { return IsValueNode(node); }); if (value_list.empty()) { all_is_null = true; } @@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } size_t FuncGraph::GetDefaultValueCount() { int null_count = std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), - [](const std::pair& pair) { return IsValueNode(pair.second); }); + [](const std::pair &pair) { return IsValueNode(pair.second); }); return parameter_default_value_.size() - IntToSize(null_count); } @@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const { return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); } -AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { +AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { for (size_t i = 0; i < parameters_.size(); ++i) { MS_EXCEPTION_IF_NULL(parameters_[i]); auto param_cast = parameters_[i]->cast(); @@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { return nullptr; } -void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, - std::vector* specialized_parameter_list, - std::unordered_map* repl_nodes, int variable_args_count, +void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + std::unordered_map *repl_nodes, int variable_args_count, int pos_args_input_count) { // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple if (specialized_graph->has_vararg()) { @@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, } } -void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, - std::vector* specialized_parameter_list, - const std::vector& kwarg_list, - std::unordered_map* repl_nodes) { +void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + const std::vector &kwarg_list, + std::unordered_map *repl_nodes) { std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - for (const auto& kwarg : kwarg_list) { + for (const auto &kwarg : kwarg_list) { MS_EXCEPTION_IF_NULL(kwarg); std::string kw_param_name = kwarg->get_key(); MS_EXCEPTION_IF_NULL(specialized_graph); @@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; MS_EXCEPTION_IF_NULL(specialized_parameter_list); auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), - [param_name](const AnfNodePtr& node) { + [param_name](const AnfNodePtr &node) { 
MS_EXCEPTION_IF_NULL(node); auto param = node->cast(); return param != nullptr && param->name() == param_name; @@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); } -void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, - std::unordered_map* repl_nodes, - const std::vector& kwarg_keys_tuple_nodes, - const std::vector& kwarg_values_tuple_nodes) { +void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, + std::unordered_map *repl_nodes, + const std::vector &kwarg_keys_tuple_nodes, + const std::vector &kwarg_values_tuple_nodes) { if (has_kwarg()) { MS_EXCEPTION_IF_NULL(specialized_graph); TraceManager::DebugTrace( @@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, } } -bool FuncGraph::NeedGenerate(const std::vector& kwarg_list) { +bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { // if the function does not have any vararg/kwarg/kwonly/default value/kw args input // return the original graph if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { @@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector& return true; } -void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, - const std::vector& specialized_parameter_list, - std::unordered_map* repl_nodes) { +void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, + const std::vector &specialized_parameter_list, + std::unordered_map *repl_nodes) { MS_EXCEPTION_IF_NULL(specialized_graph); for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { auto param_node = specialized_graph->parameters()[i]; @@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, } } -FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { std::vector kwarg_list; size_t arguments_count = args_spec_list.size(); - for (const auto& arg : args_spec_list) { + for (const auto &arg : args_spec_list) { // if it is a keyword argument MS_EXCEPTION_IF_NULL(arg); if (arg->isa()) { @@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) MS_EXCEPTION_IF_NULL(specialized_graph); auto params = specialized_graph->parameters(); (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), - std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; }); + std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); std::shared_ptr manager = mindspore::Manage(specialized_graph, false); auto tr = manager->Transact(); - for (auto& node_pair : repl_nodes) { + for (auto &node_pair : repl_nodes) { MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" << node_pair.second->DebugString(); (void)tr.Replace(node_pair.first, node_pair.second); @@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) return specialized_graph; } -void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); } +void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } std::list FuncGraph::GetOrderedCnodes() { if 
(has_flag(GRAPH_FLAG_HAS_EFFECT)) { @@ -651,7 +651,7 @@ std::list FuncGraph::GetOrderedCnodes() { std::list cnodes; auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); - for (const auto& node : nodes) { + for (const auto &node : nodes) { auto cnode = dyn_cast(node); if (cnode) { cnodes.push_back(cnode); @@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() { } } -void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) { +void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { order_.remove(n->cast()); MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; @@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { MS_LOG(DEBUG) << "Check graph " << ToString(); for (auto it = order_.begin(); it != order_.end(); (void)it++) { - for (const auto& input_node : (*it)->inputs()) { + for (const auto &input_node : (*it)->inputs()) { if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { // Need to reorder the wrong order node. auto found = std::find(order_.begin(), it, input_node); @@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() { } auto mng = manager_.lock(); if (mng != nullptr) { - const auto& nodes = mng->nodes()[shared_from_base()]; + const auto &nodes = mng->nodes()[shared_from_base()]; if (nodes.size() != (order_.size() + parameters_.size())) { DumpCNodeList(); MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " @@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() { const char kPrimHasEffect[] = "_side_effect_flag"; -bool FuncGraph::HasEffect(const CNodePtr& cnode) { +bool FuncGraph::HasEffect(const CNodePtr &cnode) { auto prim = GetCNodePrimitive(cnode); if (prim != nullptr && prim->isa()) { auto do_sig = prim->cast(); @@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) { return false; } -std::shared_ptr> FindRoots(const std::vector& segment) { +std::shared_ptr> FindRoots(const std::vector &segment) { std::shared_ptr> roots = std::make_shared>(segment); - for (const auto& node : segment) { + for (const auto &node : segment) { if (roots->size() == 1) { return roots; } @@ -757,9 +757,9 @@ std::shared_ptr> FindRoots(const std::vector& seg return roots; } -std::shared_ptr> FindLeaves(const std::vector& segment) { +std::shared_ptr> FindLeaves(const std::vector &segment) { std::shared_ptr> nodes = std::make_shared>(segment); - for (const auto& node : segment) { + for (const auto &node : segment) { if (nodes->size() == 1) { return nodes; } @@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { std::list depends_order; std::vector segment; - for (const auto& cnode : order_) { + for (const auto &cnode : order_) { if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { continue; } @@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { } } -void FuncGraph::SetEffectDepends(const std::vector& depend_inputs) { +void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { auto old_ret = output(); std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index d90cdbacf27..c086b8d7d18 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -26,29 +26,29 @@ // namespace to support 
intermediate representation definition namespace mindspore { -Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, - bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation) +Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, + bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) : clone_all_valuenodes_(clone_all_valuenodes), clone_all_child_graphs_(clone_all_child_graphs), clone_all_used_graphs_(clone_all_used_graphs), relation_(relation), target_relation_(target_relation == nullptr ? relation : target_relation) { - for (auto& func_graph : func_graphs) { + for (auto &func_graph : func_graphs) { AddClone(func_graph); } scope_ = kDefaultScope; type_ = kBasic; } -void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& params, CloneType type) { +void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList ¶ms, CloneType type) { if (func_graph != nullptr) { todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); type_ = type; } } -void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); if (repl_node_.find(node) != repl_node_.end() || node->isa()) { return; @@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { } } -void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) { +void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, TraceManager::EndTrace(); } -void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { TraceManager::EndTrace(); } -void Cloner::CloneValueNode(const AnfNodePtr& node) { +void Cloner::CloneValueNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); TraceManager::DebugTrace(node->debug_info(), relation_); ValueNodePtr new_const = NewValueNode(GetValueNode(node)); @@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) { TraceManager::EndTrace(); } -void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) TraceManager::EndTrace(); } -void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { +void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_valuenodes_) { return; } - auto& value_nodes = manager_->valuenodes()[func_graph]; - for (auto& value_node : 
value_nodes) { + auto &value_nodes = manager_->valuenodes()[func_graph]; + for (auto &value_node : value_nodes) { auto old_node = value_node.first; MS_EXCEPTION_IF_NULL(old_node); if (repl_node_.count(old_node) == 0) { @@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { } } -void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) { +void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_child_graphs_) { return; } - auto& scopes = manager_->scopes(func_graph); - for (auto& graph : scopes) { + auto &scopes = manager_->scopes(func_graph); + for (auto &graph : scopes) { if (graph != func_graph) { todo_.push_back({graph, nullptr, {}}); } } } -void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) { +void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_used_graphs_) { return; } - auto& used_graphs = manager_->func_graphs_used()[func_graph]; - for (auto& used_graph : used_graphs) { + auto &used_graphs = manager_->func_graphs_used()[func_graph]; + for (auto &used_graph : used_graphs) { todo_.push_back({used_graph.first, nullptr, {}}); } } -void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); - for (auto& item : func_graph->parameter_default_value()) { + for (auto &item : func_graph->parameter_default_value()) { auto nodes = DeepLinkedGraphSearch(item.second); - for (auto& node : nodes) { + for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { CloneNode(node, target_func_graph); @@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F } } -void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); @@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func } target_func_graph->set_return(return_node); - auto& value_nodes = manager_->func_graph_valuenodes()[func_graph]; - for (auto& value_node : value_nodes) { + auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; + for (auto &value_node : value_nodes) { CloneValueNode(value_node.first, target_func_graph); } } -void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) { +void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { MS_EXCEPTION_IF_NULL(func_graph); - auto& old_params = func_graph->parameters(); + auto &old_params = func_graph->parameters(); if (old_params.size() != params.size()) { MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; return; @@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode } } -void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) { +void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { 
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); @@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons TraceManager::EndTrace(); } -void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); - auto& params = func_graph->parameters(); - for (auto& param : params) { + auto ¶ms = func_graph->parameters(); + for (auto ¶m : params) { CloneParameter(param, target_func_graph, true); } repl_func_graph_[func_graph] = target_func_graph; } -void Cloner::GenParameters(const FuncGraphPtr& func_graph) { +void Cloner::GenParameters(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - auto& free_vars = manager_->free_variables_total(); + auto &free_vars = manager_->free_variables_total(); auto iter = free_vars.find(func_graph); if (iter == free_vars.end()) { return; } - for (auto& fv_map : iter->second) { - auto& free_var = fv_map.first; + for (auto &fv_map : iter->second) { + auto &free_var = fv_map.first; if (utils::isa(free_var)) { repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); } } } -void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { +void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { param->set_abstract(node->abstract()); if (node->isa()) { ParameterPtr old_param = dyn_cast(node); @@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { } } -ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) { +ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { TraceManager::DebugTrace(std::make_shared(node->debug_info())); ParameterPtr param = std::make_shared(func_graph); TraceManager::EndTrace(); @@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP return param; } -void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, - AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) { +void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, + AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { AnfNodePtrList parameters; std::unordered_set old_params; - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { auto iter = repl_node_.find(param); if (iter != repl_node_.end()) { (void)old_params.insert(iter->second); @@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& } } AnfNodePtr new_param = nullptr; - for (auto& param : params) { + for (auto ¶m : params) { auto old_param = repl_node_[param]; if (old_param->isa() && old_param->func_graph() == func_graph) { repl_node_[old_param] = old_param; @@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& func_graph->set_parameters(parameters); } -void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params) { +void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const 
AnfNodePtrList ¶ms) { AnfNodePtr node = nullptr; - auto& repl_func_graph = repl_map_func_graph_[func_graph_user]; + auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; auto iter = repl_func_graph.find(func_graph); if (iter == repl_func_graph.end()) { node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); @@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& OrderParameters(func_graph, inputs); } -void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) { +void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { std::unordered_set old_params; - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { (void)old_params.insert(repl_node_[param]); } std::unordered_set new_params; @@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis (void)new_params.insert(new_param); } } - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { if (new_params.find(param) == new_params.end()) { parameters.push_back(param); } @@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis func_graph->set_parameters(parameters); } -void Cloner::SetEdges(const FuncGraphPtr& func_graph) { +void Cloner::SetEdges(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - for (auto& node : func_graph->nodes()) { + for (auto &node : func_graph->nodes()) { if (node == nullptr) { continue; } @@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { continue; } auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); for (size_t i = 0; i < inputs.size(); i++) { - auto& input = inputs[i]; + auto &input = inputs[i]; if (IsValueNode(input)) { auto graph = GetValueNode(input); - auto& repl_func_graph = repl_map_func_graph_[func_graph]; + auto &repl_func_graph = repl_map_func_graph_[func_graph]; if (repl_func_graph.find(graph) != repl_func_graph.end()) { transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); } } else { - auto& repl_node = repl_map_node_[func_graph]; + auto &repl_node = repl_map_node_[func_graph]; if (repl_node.find(input) != repl_node.end()) { transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); } @@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { } } -void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params) { +void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { AnfNodePtrList lift_params; AnfNodePtrList input_params; AddParameters(func_graph_user, params, &lift_params, &input_params); @@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph if (lift_params.empty()) { return; } - for (auto& user : func_graph_user->func_graph_users()) { + for (auto &user : func_graph_user->func_graph_users()) { LiftParameters(user.first, func_graph_user, lift_params); } } void Cloner::Lift() { - for (auto& func_graph_params : repl_func_graph_params_) { - auto& func_graph = func_graph_params.first; - auto& params = func_graph_params.second; - for (auto& user : func_graph->func_graph_users()) { + for (auto &func_graph_params : repl_func_graph_params_) { + auto &func_graph = func_graph_params.first; + auto ¶ms = 
func_graph_params.second; + for (auto &user : func_graph->func_graph_users()) { LiftParameters(user.first, func_graph, params); } } @@ -404,18 +404,18 @@ void Cloner::Lift() { void Cloner::LiftParameters() { MS_EXCEPTION_IF_NULL(manager_); transaction_ = manager_->Transact(); - const FuncGraphSet& func_graphs = manager_->func_graphs(); - for (auto& func_graph : func_graphs) { + const FuncGraphSet &func_graphs = manager_->func_graphs(); + for (auto &func_graph : func_graphs) { GenParameters(func_graph); } Lift(); - for (auto& func_graph : func_graphs) { + for (auto &func_graph : func_graphs) { SetEdges(func_graph); } transaction_.Commit(); } -bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { +bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { MS_EXCEPTION_IF_NULL(func_graph); // Make sure only inline once if (status_.count(func_graph) != 0) { @@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { return true; } -void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); - const AnfNodeSet& nodes = manager_->nodes()[func_graph]; - for (auto& node : nodes) { + const AnfNodeSet &nodes = manager_->nodes()[func_graph]; + for (auto &node : nodes) { CloneNode(node, target_func_graph); } } @@ -449,7 +449,7 @@ void Cloner::Run() { // Basic and Inline Clone FuncGraphPtrList func_graphs; (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), - [](const CloneInfo& item) -> FuncGraphPtr { return item.origin; }); + [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); manager_ = Manage(func_graphs, false); CloneNodes(); LinkEdges(); @@ -495,13 +495,13 @@ void Cloner::CloneNodes() { } void Cloner::LinkEdges() { - for (auto& node_pair : nodes_) { + for (auto &node_pair : nodes_) { CNodePtr old_node = node_pair.first; CNodePtr new_node = node_pair.second; MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(new_node); - for (auto& input : old_node->inputs()) { - auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; + for (auto &input : old_node->inputs()) { + auto &new_input = (repl_node_.count(input) == 0) ? 
input : repl_node_[input]; new_node->add_input(new_input); } } @@ -509,10 +509,10 @@ void Cloner::LinkEdges() { // For the graphs cloned, update its default value map to the cloned nodes void Cloner::SetDefaults() { - for (auto& item : graph_set_) { + for (auto &item : graph_set_) { MS_EXCEPTION_IF_NULL(item); if (repl_func_graph_.count(item) != 0) { - for (auto& param_def : item->parameter_default_value()) { + for (auto ¶m_def : item->parameter_default_value()) { MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); if (repl_node_.count(param_def.second) != 0) { repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); @@ -524,7 +524,7 @@ void Cloner::SetDefaults() { } } -AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { +AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { MS_EXCEPTION_IF_NULL(root); if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; @@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; } -AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { +AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double time = GetTime(); #endif @@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); } -FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { +FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { #ifdef ENABLE_PROFILE double time = GetTime(); #endif @@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { return ((repl_func_graph_.count(func_graph) == 0) ? 
func_graph : repl_func_graph_[func_graph]); } -FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) { +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); return cloner[func_graph]; } -AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& func_graph_args, const ScopePtr& scope) { +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); Cloner cloner({}, false); @@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe return cloner[func_graph->output()]; } -FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) { +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); Cloner cloner({}, false); cloner.AddClone(func_graph, nullptr, {}, kLifting); return cloner[func_graph]; } -ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { MS_EXCEPTION_IF_NULL(func_graph); FuncGraphPtrList func_graphs = {func_graph}; ClonerPtr cloner = @@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r return cloner; } -FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { MS_EXCEPTION_IF_NULL(func_graph); TraceManager::DebugTrace(func_graph->debug_info(), relation); auto new_func_graph = std::make_shared(); TraceManager::EndTrace(); - auto& parameters = func_graph->parameters(); - (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void { + auto ¶meters = func_graph->parameters(); + (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { MS_EXCEPTION_IF_NULL(param); TraceManager::DebugTrace(std::make_shared(param->debug_info())); (void)new_func_graph->add_parameter(); @@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); new_func_graph->set_is_generate(func_graph->is_generated()); - for (auto& item : func_graph->parameter_default_value()) { + for (auto &item : func_graph->parameter_default_value()) { new_func_graph->set_param_default_value(item.first, cloner[item.second]); } diff --git a/mindspore/ccsrc/ir/func_graph_cloner.h b/mindspore/ccsrc/ir/func_graph_cloner.h index dd228cf79f2..426cf447a3f 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.h +++ b/mindspore/ccsrc/ir/func_graph_cloner.h @@ -43,26 +43,26 @@ struct CloneInfo { class Cloner { public: - explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, + explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false, bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, - const TraceInfoPtr& relation = std::make_shared(), - const TraceInfoPtr& target_relation = nullptr); + const TraceInfoPtr &relation = std::make_shared(), + const 
TraceInfoPtr &target_relation = nullptr); ~Cloner() = default; - void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, - const AnfNodePtrList& params = {}, CloneType type = kBasic); + void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, + const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); void Run(); // Interfaces for specializer - AnfNodePtr CloneDisconnected(const AnfNodePtr& root); - AnfNodePtr operator[](const AnfNodePtr& node); - FuncGraphPtr operator[](const FuncGraphPtr& func_graph); + AnfNodePtr CloneDisconnected(const AnfNodePtr &root); + AnfNodePtr operator[](const AnfNodePtr &node); + FuncGraphPtr operator[](const FuncGraphPtr &func_graph); // Map of replicate nodes and graphs - std::unordered_map* cloned_node() { return &repl_node_; } + std::unordered_map *cloned_node() { return &repl_node_; } std::unordered_map cloned_func_graph() { return repl_func_graph_; } // Scope of cloned graphs - void set_scope(const ScopePtr& scope) { scope_ = scope; } + void set_scope(const ScopePtr &scope) { scope_ = scope; } const ScopePtr scope() const { return scope_; } std::unordered_map repl_node_; @@ -71,31 +71,31 @@ class Cloner { void CloneNodes(); void LinkEdges(); void SetDefaults(); - void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneValueNode(const AnfNodePtr& node); - void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); - void CloneValueNodes(const FuncGraphPtr& func_graph); - void AddChildGraphs(const FuncGraphPtr& func_graph); - void AddTotalGraphs(const FuncGraphPtr& func_graph); - bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); - void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); - void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); - void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void GenParameters(const FuncGraphPtr& func_graph); - void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); - ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); - void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, - AnfNodePtrList* const input_params); - void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); - void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); - void SetEdges(const FuncGraphPtr& func_graph); - void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params); + void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneValueNode(const AnfNodePtr &node); + void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneParameter(const AnfNodePtr &node, const 
FuncGraphPtr &target, bool is_add = false); + void CloneValueNodes(const FuncGraphPtr &func_graph); + void AddChildGraphs(const FuncGraphPtr &func_graph); + void AddTotalGraphs(const FuncGraphPtr &func_graph); + bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); + void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params); + void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph); + void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void GenParameters(const FuncGraphPtr &func_graph); + void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node); + ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); + void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params, AnfNodePtrList *const lift_params, + AnfNodePtrList *const input_params); + void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList &params); + void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs); + void SetEdges(const FuncGraphPtr &func_graph); + void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList &params); void Lift(); void LiftParameters(); @@ -118,17 +118,17 @@ class Cloner { std::unordered_map repl_func_graph_params_; }; -FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); -AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); -FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph); -ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); -FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, - const TraceInfoPtr& relation = std::make_shared()); +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, + const TraceInfoPtr &relation = std::make_shared()); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 889a091711a..a53c9e95aeb 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -27,17 +27,17 @@ namespace mindspore { -FuncGraphManagerPtr MakeManager(const std::vector& func_graphs, bool manage) { +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { auto m = std::make_shared(func_graphs, manage); m->Init(); return m; } -FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool manage) { +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { FuncGraphManagerPtr m = nullptr; bool root = false; - for (auto& fg : func_graphs) { + for (auto &fg : func_graphs) { if (fg == nullptr) { 
continue; } @@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool ma root = true; } - for (auto& fg : func_graphs) { + for (auto &fg : func_graphs) { if (fg == nullptr) { continue; } @@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { return Manage(func_graphs, manage); } -FuncGraphManager::FuncGraphManager(const std::vector& roots, bool manage) +FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) : roots_(roots), is_manage_(manage) { Reset(); } @@ -103,12 +103,12 @@ void FuncGraphManager::Init() { auto roots = roots_; roots_ = FuncGraphSet(); - for (auto& fg : roots) { + for (auto &fg : roots) { AddFuncGraph(fg, true); } } -FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); func_graph_parents_total_->Recompute(fg); @@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; } -FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { +FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(func_graph_parent_); MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); @@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { return func_graph_parent_->parent_analysis()[fg]; } -FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(children_); MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); @@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { return children_->children_analysis()[fg]; } -FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(scopes_); MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); @@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { return scopes_->scope_analysis()[fg]; } -FVTotalMap& FuncGraphManager::free_variables_total() const { +FVTotalMap &FuncGraphManager::free_variables_total() const { MS_EXCEPTION_IF_NULL(free_variables_total_); free_variables_total_->Recompute(); return free_variables_total_->fv_total_analysis(); } -FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(func_graphs_used_total_); func_graphs_used_total_->Recompute(fg); return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; } -bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { +bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); recursive_->Recompute(fg); if (recursive_->recursive_analysis().count(fg) == 0) { @@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { return recursive_->recursive_analysis()[fg]; } -std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const { +std::shared_ptr> FuncGraphManager::recursive_graphs(const 
FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); if (recursive(fg)) { if (!recursive_->recursive_map().count(fg)) { @@ -185,7 +185,7 @@ std::shared_ptr> FuncGraphManager::recursive_graphs(cons } } -bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const { +bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(j_total_); MS_EXCEPTION_IF_NULL(fg); j_total_->Recompute(fg); @@ -225,10 +225,10 @@ void FuncGraphManager::Clear() { signals_->InvalidateComputer(); } -void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { +void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { MS_LOG(DEBUG) << "Start keep roots"; bool root_exist = false; - for (auto& item : func_graphs) { + for (auto &item : func_graphs) { if (roots_.contains(item)) { root_exist = true; break; @@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { roots = roots_; } else { roots_.clear(); - for (auto& item : roots) { + for (auto &item : roots) { AddFuncGraph(item, true); } } FuncGraphSet keep; - for (auto& item : roots) { + for (auto &item : roots) { MS_LOG(DEBUG) << "roots: " << item->ToString(); keep.update(func_graphs_used_total(item)); #ifdef DEBUG - for (auto& k : keep) { + for (auto &k : keep) { MS_LOG(DEBUG) << "keep: " << k->ToString(); } #endif @@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { } else { Clear(); FuncGraphSet roots(func_graphs); - for (auto& item : roots) { + for (auto &item : roots) { AddFuncGraph(item, true); } } @@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() { MaybeDropFuncGraphs(func_graphs_, true); } -void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { +void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(fg); if (is_manage_) { if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { @@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { func_graphs_.add(fg); } -void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) { +void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { FuncGraphSet todo(func_graphs); std::set dropped; // int count = 0; @@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool continue; } MS_EXCEPTION_IF_NULL(func_graph_users_); - auto& users = func_graph_users_->count_func_graphs_map()[func_graph]; + auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; if (!users.empty() && !ignore_users) { MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); continue; @@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool todo.update(MaybeDropNodes(return_vec)); } MS_EXCEPTION_IF_NULL(signals_); - for (auto& fg : dropped) { + for (auto &fg : dropped) { MS_EXCEPTION_IF_NULL(fg); signals_->DropFuncGraph(fg); all_nodes_.difference_update(fg->parameters()); @@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E MS_EXCEPTION_IF_NULL(inp); if (direction == kDecEdge) { MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - auto& users_node = node_users_[inp]; + auto &users_node = node_users_[inp]; if (!users_node.contains(make_pair(node, index))) { return; } @@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E 
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); AddFuncGraph(GetValueNode(inp)); } - auto& users_node = node_users_[inp]; + auto &users_node = node_users_[inp]; users_node.add(make_pair(node, index)); MS_EXCEPTION_IF_NULL(signals_); signals_->AddEdge(node, index, inp); } } -void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) { +void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { auto cnode = node->cast(); int index = 0; - for (auto& inp : cnode->inputs()) { + for (auto &inp : cnode->inputs()) { ProcessEdge(cnode, index, inp, direction); ++index; } } } -IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { +IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { if (all_nodes_.contains(node)) { return EXCLUDE; } else { @@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { } } -void FuncGraphManager::AcquireNodes(const std::vector& nodes) { +void FuncGraphManager::AcquireNodes(const std::vector &nodes) { AnfNodeSet acq; - for (auto& node : nodes) { + for (auto &node : nodes) { std::function limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); @@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector& nodes) { acq.update(new_nodes); } - for (auto& node : acq) { + for (auto &node : acq) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr fg = node->func_graph(); if (fg != nullptr) { @@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector& nodes) { } } -FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& nodes) { +FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { AnfNodeSet nodes_ordered(nodes); FuncGraphSetPtr func_graphs_to_check = std::make_shared(); MS_EXCEPTION_IF_NULL(signals_); @@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& if (!all_nodes_.contains(node)) { continue; } - AnfNodeIndexSet& users = node_users_[node]; + AnfNodeIndexSet &users = node_users_[node]; std::vector parameters; if (!users.empty() || @@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& return func_graphs_to_check; } -void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector& parameters) { +void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { auto tr = Transact(); tr.SetParameters(fg, parameters); tr.Commit(); } -bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { +bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { auto tr = Transact(); bool success = tr.Replace(old_node, new_node); if (success) { @@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new return success; } -void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) { +void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { auto tr = Transact(); tr.SetEdge(node, index, value); tr.Commit(); } -void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) { +void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { AnfNodePtr source_return = source->get_return(); AnfNodePtr 
source_output = source->output(); AnfNodePtr source_prim = source_return->cast()->input(0); @@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t (void)all_nodes_.erase(source_return); (void)node_users_.erase(source_return); signals_->DropNode(source_return); - for (auto& node : source->nodes()) { + for (auto &node : source->nodes()) { node->set_func_graph(target); if (node->scope() == kDefaultScope) { node->set_scope(scope); } } - for (auto& used : source->func_graphs_used()) { + for (auto &used : source->func_graphs_used()) { (void)func_graph_users_->Inc(used.first, target, used.second); (void)this->func_graph_users()[used.first].erase(source); } - for (auto& child : this->func_graph_child_direct()[source]) { + for (auto &child : this->func_graph_child_direct()[source]) { (void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)this->func_graph_parents_direct()[child.first].erase(source); } - for (auto& fv_count : this->free_variables_direct()[source]) { + for (auto &fv_count : this->free_variables_direct()[source]) { auto fv_g = fv_count.first->func_graph(); - auto& count_on_g = this->func_graph_child_direct()[fv_g]; + auto &count_on_g = this->func_graph_child_direct()[fv_g]; auto pair = count_on_g.find(source); if (fv_g != target && pair != count_on_g.end()) { (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); @@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() { return tr; } -void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupleCounter* add_edges, - EdgeTupleCounter* rm_edges, Counter* adds, Counter* rms) { - for (auto& iter : changes) { +void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, + EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { + for (auto &iter : changes) { auto operation = iter.op; auto args = iter.args; if (operation == Change::kTxSetEdge) { @@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupl auto param = args.cast(); MS_EXCEPTION_IF_NULL(param.func_graph); auto old_parameters = param.func_graph->parameters(); - for (auto& p : param.params) { + for (auto &p : param.params) { (*adds)[p] += 1; } - for (auto& p : old_parameters) { + for (auto &p : old_parameters) { (*rms)[p] += 1; } param.func_graph->set_parameters(param.params); @@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupl } } -void FuncGraphManager::CommitChanges(const std::vector& changes) { +void FuncGraphManager::CommitChanges(const std::vector &changes) { EdgeTupleCounter add_edges; EdgeTupleCounter rm_edges; Counter adds; @@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); auto sub_edges = add_edges - rm_edges; - for (auto& iter : sub_edges) { + for (auto &iter : sub_edges) { auto root_node = iter.first.first; int index = iter.first.second.first; auto new_node = iter.first.second.second; @@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { auto sub_nodes = adds - rms; std::vector nodes; (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), - [](const std::pair& iter) -> AnfNodePtr { return iter.first; }); + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); AcquireNodes(nodes); auto sub_edges_reverse = rm_edges - add_edges; - for (auto& iter : sub_edges_reverse) { + for (auto &iter : 
sub_edges_reverse) { auto root_node = iter.first.first; int index = iter.first.second.first; auto old_node = iter.first.second.second; @@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { std::vector nodes_reverse; (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), - [](const std::pair& iter) -> AnfNodePtr { return iter.first; }); + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); auto drop_func_graphs = MaybeDropNodes(nodes_reverse); MaybeDropFuncGraphs(*drop_func_graphs); } -void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector& params) { +void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector &params) { changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); } -bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { +bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(new_node); FuncGraphPtr old_func_graph = old_node->func_graph(); @@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& return false; } auto users = manager_->node_users()[old_node]; - for (auto& node : users) { + for (auto &node : users) { SetEdge(node.first, node.second, new_node); } return true; } -void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) { +void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { if (k < 0) { MS_LOG(EXCEPTION) << "Invalid value k = " << k; } @@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() { manager_->CommitChanges(changes); } -FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) +FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) : manager_(manager), include_func_graph_none_(false) { manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); @@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); } -NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() { +NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { include_func_graph_none_ = true; nodes_analysis_[nullptr] = AnfNodeSet(); @@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) { void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // change the owner of node except for the src's return node - for (auto& it : nodes_analysis_[src]) { + for (auto &it : nodes_analysis_[src]) { nodes_analysis_[dst].add(it); } (void)nodes_analysis_.erase(src); @@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } -DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { +DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); } void DepCollector::OnDropEdge(AnfNodePtr 
node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } -bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { - auto& d = count_nodes_map_[func_graph]; +bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { + auto &d = count_nodes_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; return true; @@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { +bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { MS_EXCEPTION_IF_NULL(func_graph); - auto& d = count_nodes_map_[func_graph]; + auto &d = count_nodes_map_[func_graph]; if (d.count(key) != 0) { if (d[key] == count) { (void)d.erase(key); @@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) { +bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { @@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP } } -bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { - auto& d = count_func_graphs_map_[func_graph]; +bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; return true; @@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr return false; } -bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { - auto& d = count_func_graphs_map_[func_graph]; +bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) != 0) { if (d[key] == count) { (void)d.erase(key); @@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr return false; } -bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) { +bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { @@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr } void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed } void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc } void 
FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { FuncGraphPtr fg2 = it.first->func_graph(); if (fg2 != dst) { (void)Inc(dst, it.first, it.second); @@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { (void)count_nodes_map_.erase(src); } -static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) { +static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { FuncGraphPtr gn = std::make_shared(); (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); return gn; @@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP } void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { FuncGraphPtr fg = it.first; if (fg != dst) { (void)Inc(dst, fg, it.second); @@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr } void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { if (it.first != dst) { (void)Inc(dst, it.first, it.second); } @@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // all graph use in src need to change to dst, so meger the to dst use - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_func_graphs_map_[dst].erase(src); @@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp } void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // all graph use in src need to change to dst, so meger the to dst use - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_func_graphs_map_.erase(src); } -DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { +DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); validate_ = false; @@ -914,20 +914,20 @@ void DepComputer::Recompute() { } } -void DepComputer::Recompute(const FuncGraphPtr& fg) { +void DepComputer::Recompute(const FuncGraphPtr &fg) { if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { RealRecompute(fg); func_graphs_validate_[fg] = true; } } -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { if (path == nullptr || path->contains(fg)) { return std::make_shared(); } FuncGraphSetPtr parents = 
std::make_shared(); - FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_; - for (auto& dep : deps[fg]) { + FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; + for (auto &dep : deps[fg]) { MS_EXCEPTION_IF_NULL(dep.first); auto proxy = dep.first->transforms().find("proxy"); if (proxy != dep.first->transforms().end()) { @@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); } -bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) { +bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { auto l1 = lhs.second.size(); auto l2 = rhs.second.size(); return l1 < l2; @@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { } else { // return nearest parent as parent FuncGraphSet deps_copy(deps); - for (auto& dep : deps) { + for (auto &dep : deps) { auto parent_deps = this->manager_->func_graph_parents_total(dep); - for (auto& p_d : parent_deps) { + for (auto &p_d : parent_deps) { if (deps_copy.count(p_d)) { (void)deps_copy.erase(p_d); } @@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); auto used_fg_total = manager_->func_graphs_used_total(fg); - for (auto& used_fg : used_fg_total) { + for (auto &used_fg : used_fg_total) { if (manager_->parent(used_fg) == fg) { children_analysis_[fg].add(used_fg); } @@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { void ScopeComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto& children = manager_->children(fg); + auto &children = manager_->children(fg); scope_analysis_[fg] = FuncGraphSet(); scope_analysis_[fg].add(fg); - for (auto& child : children) { + for (auto &child : children) { scope_analysis_[fg].add(child); } } @@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() { auto manager = DepComputer::manager_; MS_EXCEPTION_IF_NULL(manager); - for (auto& fg : manager->func_graphs()) { + for (auto &fg : manager->func_graphs()) { fv_total_analysis_[fg] = OrderedMap(); count_nodes_map_[fg] = OrderedMap(); count_func_graphs_map_[fg] = OrderedMap(); } - for (auto& fg : manager->func_graphs()) { + for (auto &fg : manager->func_graphs()) { AnfNodeCounterMap items = manager->free_variables_direct()[fg]; - for (auto& iter : items) { + for (auto &iter : items) { auto curr = fg; while (curr) { (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); curr = manager->parent(curr); - const AnfNodeSet& nodes = manager->nodes()[curr]; + const AnfNodeSet &nodes = manager->nodes()[curr]; if (nodes.contains(iter.first)) { break; } @@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() { } auto items_fg = manager->func_graphs_used()[fg]; - for (auto& iter : items_fg) { + for (auto &iter : items_fg) { auto p = manager->parent(iter.first); if (p == nullptr) { continue; @@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() { } } } - for (auto& fg : manager->func_graphs()) { - auto& fvp = count_nodes_map_[fg]; - auto& fvg = count_func_graphs_map_[fg]; - for (auto& item : fvp) { + for (auto &fg : manager->func_graphs()) { + auto &fvp = count_nodes_map_[fg]; + auto &fvg = count_func_graphs_map_[fg]; + for (auto &item : fvp) { fv_total_analysis_[fg][item.first] = item.second; } - for (auto& item : fvg) { + for (auto &item : fvg) { fv_total_analysis_[fg][item.first] = 
item.second; } } @@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto& used = this->manager_->func_graphs_used(); + auto &used = this->manager_->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); while (!todo.empty()) { todo_new.clear(); - for (auto& gt : todo) { - for (auto& item : used[gt]) { + for (auto &gt : todo) { + for (auto &item : used[gt]) { auto used_fg = item.first; if (used_fg == fg) { func_graph_used_total_analysis_[fg].add(used_fg); @@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { } } -bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) { +bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(manager); - auto& used = manager->func_graphs_used(); + auto &used = manager->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); FuncGraphSet used_total; while (!todo.empty()) { todo_new.clear(); - for (auto& gt : todo) { - for (auto& item : used[gt]) { + for (auto &gt : todo) { + for (auto &item : used[gt]) { auto used_g = item.first; if (used_g == fg) { return true; @@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); } -void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list* trace) { +void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace) { MS_EXCEPTION_IF_NULL(trace); auto res = std::find(trace->begin(), trace->end(), fg); // find recursive @@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::listpush_back(fg); - auto& used_fgs = manager_->func_graphs_used()[fg]; + auto &used_fgs = manager_->func_graphs_used()[fg]; for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { CheckRecursiveGraphs(iter->first, trace); } @@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::listcontains(fg)) { MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; return false; } MS_EXCEPTION_IF_NULL(manager_); - auto& func_graph_counter_map = manager_->func_graph_j_direct(); + auto &func_graph_counter_map = manager_->func_graph_j_direct(); if (!func_graph_counter_map[fg].empty()) { // check g1->J(fg)->g2->g cycle; auto contains_j = @@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt path->add(fg); // check if func graphs used contains J(func_graph); - auto& used = this->manager_->func_graphs_used(); - for (auto& item : used[fg]) { + auto &used = this->manager_->func_graphs_used(); + for (auto &item : used[fg]) { auto used_g = item.first; if (SeekJ(used_g, path)) { MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index aaf5a0aa5fc..54c1e8a6923 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -46,13 +46,13 @@ class FuncGraphManager; using FuncGraphManagerPtr = std::shared_ptr; struct AnfNodeIndexPairHasher { - std::size_t operator()(const std::pair& p1) const { - return std::hash{}(p1.first.get()); + std::size_t operator()(const std::pair &p1) const { + return std::hash{}(p1.first.get()); } }; struct AnfNodeIndexPairEqual { - bool operator()(const std::pair& lhs, const 
std::pair& rhs) const { + bool operator()(const std::pair &lhs, const std::pair &rhs) const { return lhs == rhs; } }; @@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair; using FuncGraphSetPtr = std::shared_ptr; using EdgeTuple = std::pair>; struct EdgeTupleHasher { - std::size_t operator()(const EdgeTuple& p1) const { - return hash_combine({std::hash{}(p1.first.get()), std::hash{}(p1.second.first), - std::hash{}(p1.second.second.get())}); + std::size_t operator()(const EdgeTuple &p1) const { + return hash_combine({std::hash{}(p1.first.get()), std::hash{}(p1.second.first), + std::hash{}(p1.second.second.get())}); } }; struct EdgeTupleEqual { - bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const { + bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const { return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; } }; @@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter; // FuncGraphManagerPtr: return created manager FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); -FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool manage = true); +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage = true); -FuncGraphManagerPtr MakeManager(const std::vector& func_graphs = {}, bool manage = true); +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs = {}, bool manage = true); struct Signals { Signal AddFuncGraph; @@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap; // graphs analysis which compute in write, read needn't recompute class DepCollector : public FuncGraphAnalysis { public: - explicit DepCollector(const FuncGraphManager* manager); + explicit DepCollector(const FuncGraphManager *manager); ~DepCollector() override = default; void Reset() { ExtraReset(); } @@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis { class NodesCollector final : public DepCollector { public: - explicit NodesCollector(const FuncGraphManager* m); + explicit NodesCollector(const FuncGraphManager *m); ~NodesCollector() override = default; - const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; } + const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } size_t size() const override { return nodes_analysis_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } @@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector { class CounterFuncGraphCollector : public DepCollector { public: - explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {} + explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterFuncGraphCollector() override = default; - FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } // inherit from FuncGraphAnalysis size_t size() const override { return count_func_graphs_map_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap(); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } - bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); - bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); - bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); + bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr 
&key, int count); + bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); FuncGraphToFuncGraphCounterMap count_func_graphs_map_; @@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector { class CounterAnfNodeCollector : public DepCollector { public: - explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {} + explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } size_t size() const override { return count_nodes_map_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap(); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); - bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); - bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); + bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); FuncGraphToAnfNodeCounterMap count_nodes_map_; @@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector { class ValueNodesCollector final : public CounterAnfNodeCollector { public: - explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~ValueNodesCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { public: - explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FuncGraphValueNodesCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { class FVDirectCollector final : public CounterAnfNodeCollector { public: - explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FVDirectCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector { class FuncGraphChildDirect final : public CounterFuncGraphCollector { public: - explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphChildDirect() override = default; @@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector { // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f class 
FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} ~FuncGraphParentsDirectCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { // graph's all used graphs: key is g, value is g used graph class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphsUsedCollector() override = default; @@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { // graph's all user graphs: key is g, value is graphs who used g class FuncGraphUsersCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphUsersCollector() override = default; @@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector { // graph's all user cnodes: key is g, value is cnodes who used g class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { public: - explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphUserNodesCollector() override = default; @@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; ~FuncGraphJDirectCollector() override = default; @@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap; // graphs analysis which need dynamic compute by DepCollector in each read class DepComputer : public FuncGraphAnalysis { public: - explicit DepComputer(const FuncGraphManager* manager); + explicit DepComputer(const FuncGraphManager *manager); ~DepComputer() override = default; void Reset() { @@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis { void Recompute(); - void Recompute(const FuncGraphPtr& fg); + void Recompute(const FuncGraphPtr &fg); bool IsValidate() const { return validate_; } - bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; } + bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } @@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis { // graph g's all direct or proxy parents class FuncGraphParentsTotalComputer final : public DepComputer { public: - explicit 
FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {} + explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } - FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } + FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } size_t size() const override { return func_graph_parents_total_analysis_.size(); } @@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer { void RealRecompute(FuncGraphPtr fg) override; private: - FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared()); + FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared()); // when SeekParents calls itself recursively, it can access these variables by class member // other than pass by formal parameters, it can save 1 parameter for SeekParents(). - FuncGraphToFuncGraphCounterMap* all_parents_direct_; + FuncGraphToFuncGraphCounterMap *all_parents_direct_; }; using FuncGraphToFuncGraphMap = OrderedMap; @@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap; // graph's nearest parent in parents total class ParentComputer final : public DepComputer { public: - explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ParentComputer() override = default; - FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; } + FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; } size_t size() const override { return parent_analysis_.size(); } @@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer { // graph's children graph except self class ChildrenComputer final : public DepComputer { public: - explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ChildrenComputer() override = default; - FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } + FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; } size_t size() const override { return children_analysis_.size(); } @@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer { // graph's children graph include self class ScopeComputer final : public DepComputer { public: - explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ScopeComputer() override = default; - FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; } + FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; } size_t size() const override { return scope_analysis_.size(); } @@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap* trace); + void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace); size_t size() const override { return recursive_analysis_.size(); } @@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer { class FuncGraphJTotalComputer final : public DepComputer { public: - explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} ~FuncGraphJTotalComputer() 
override = default; - FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; } + FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; } size_t size() const override { return j_total_analysis_.size(); } @@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer { void ExtraReset() override { j_total_analysis_.clear(); } void RealRecompute(FuncGraphPtr fg) override; - bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path); + bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); }; class FuncGraphManager : public std::enable_shared_from_this { public: - explicit FuncGraphManager(const std::vector& roots, bool manage = true); + explicit FuncGraphManager(const std::vector &roots, bool manage = true); ~FuncGraphManager() { if (is_manage_) { RemoveRoots(); } @@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this { void Init(); void Clear(); void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false); - void KeepRoots(const std::vector& roots = {}); + void KeepRoots(const std::vector &roots = {}); void RemoveRoots(); - void SetParameters(const FuncGraphPtr& fg, const std::vector& parameters); - void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false); - bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); - void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value); - void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope); + void SetParameters(const FuncGraphPtr &fg, const std::vector &parameters); + void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); + bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); + void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); + void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope); FuncGraphTransaction Transact(); - void CommitChanges(const std::vector& changes); + void CommitChanges(const std::vector &changes); bool IsManaged() const { return is_manage_; } - const FuncGraphSet& roots() const { return roots_; } + const FuncGraphSet &roots() const { return roots_; } - const FuncGraphSet& func_graphs() const { return func_graphs_; } + const FuncGraphSet &func_graphs() const { return func_graphs_; } - AnfNodeSet& all_nodes() { return all_nodes_; } + AnfNodeSet &all_nodes() { return all_nodes_; } - NodeUsersMap& node_users() { return node_users_; } + NodeUsersMap &node_users() { return node_users_; } - FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; } + FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } - FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return 
func_graphs_used_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } - FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const { + FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { return func_graph_child_direct_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const { + FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { return func_graph_parents_direct_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } - FVTotalMap& free_variables_total() const; + FVTotalMap &free_variables_total() const; - FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const; + FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; - FuncGraphSet& scopes(const FuncGraphPtr& fg) const; + FuncGraphSet &scopes(const FuncGraphPtr &fg) const; - FuncGraphPtr parent(const FuncGraphPtr& fg) const; + FuncGraphPtr parent(const FuncGraphPtr &fg) const; - FuncGraphSet& children(const FuncGraphPtr& fg) const; + FuncGraphSet &children(const FuncGraphPtr &fg) const; - FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const; + FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const; - bool recursive(const FuncGraphPtr& fg) const; - std::shared_ptr> recursive_graphs(const FuncGraphPtr& fg) const; + bool recursive(const FuncGraphPtr &fg) const; + std::shared_ptr> recursive_graphs(const FuncGraphPtr &fg) const; - bool func_graph_j_total(const FuncGraphPtr& fg) const; + bool func_graph_j_total(const FuncGraphPtr &fg) const; std::shared_ptr signals() const { return signals_; } - IncludeType Limit(const AnfNodePtr& node); + IncludeType Limit(const AnfNodePtr &node); // Static Analysis NodeUsersMap node_users_; @@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this { std::shared_ptr func_graph_parent_; private: - void AddIntoManaged(const FuncGraphPtr& fg); + void AddIntoManaged(const FuncGraphPtr &fg); void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); - void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction); - void AcquireNodes(const std::vector& nodes); - FuncGraphSetPtr MaybeDropNodes(const std::vector& nodes); - void ParseChanges(const std::vector& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges, - Counter* adds, Counter* rms); + void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction); + void AcquireNodes(const std::vector &nodes); + FuncGraphSetPtr MaybeDropNodes(const std::vector &nodes); + void ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, + Counter *adds, Counter *rms); FuncGraphSet roots_; // managed roots FuncGraphSet func_graphs_; // managed func graphs @@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this { class FuncGraphTransaction { public: - explicit 
FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() { + explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() { MS_EXCEPTION_IF_NULL(manager_); if (!manager_->IsManaged()) { MS_LOG(DEBUG) << "The manager is not managed yet"; } } ~FuncGraphTransaction() { manager_ = nullptr; } // set parameters of a func graph - void SetParameters(FuncGraphPtr fg, const std::vector& params); + void SetParameters(FuncGraphPtr fg, const std::vector &params); // replace old_node with new_node - bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); + bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); // set esge, i.e., declare setting node.inputs[key] to value. - void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v); + void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v); // commit all changes void Commit(); private: - FuncGraphManager* manager_; + FuncGraphManager *manager_; std::vector changes_; }; @@ -668,9 +668,9 @@ class FuncGraphTransaction { struct ArgsOfSetParams { FuncGraphPtr func_graph; std::vector params; - bool operator==(const ArgsOfSetParams& other) const { return &other == this; } + bool operator==(const ArgsOfSetParams &other) const { return &other == this; } - friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) { + friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) { os << "[ArgsOfSetParams]"; return os; } @@ -681,9 +681,9 @@ struct ArgsOfSetEdge { CNodePtr root_node; AnfNodePtr new_node; size_t index; - bool operator==(const ArgsOfSetEdge& other) const { return &other == this; } + bool operator==(const ArgsOfSetEdge &other) const { return &other == this; } - friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) { + friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) { os << "[ArgsOfSetEdge]"; return os; } @@ -693,7 +693,7 @@ struct Change { enum OpName { kTxSetParams, kTxSetEdge }; OpName op; Any args; - Change(OpName name, const Any& para) : op(name), args(para) {} + Change(OpName name, const Any &para) : op(name), args(para) {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/ccsrc/ir/meta_func_graph.h index 69da925e3df..482b5f90253 100644 --- a/mindspore/ccsrc/ir/meta_func_graph.h +++ b/mindspore/ccsrc/ir/meta_func_graph.h @@ -42,25 +42,25 @@ namespace mindspore { // generate a graph corresponding to these types. class MetaFuncGraph : public FuncGraphBase { public: - explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } + explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } ~MetaFuncGraph() override = default; MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); - abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); + abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node); // Return normalized versions of the arguments. // By default, this returns args unchanged. 
-  virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const {
+  virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
     return args_spec_list;
   }
-  const std::vector& signatures() const { return signatures_; }
-  void set_signatures(const std::vector& signatures) { signatures_ = signatures; }
+  const std::vector &signatures() const { return signatures_; }
+  void set_signatures(const std::vector &signatures) { signatures_ = signatures; }
   // Generate a Graph for the given abstract arguments.
-  virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) {
+  virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
     TypePtrList types;
     (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
-                         [](const AbstractBasePtr& arg) -> TypePtr {
+                         [](const AbstractBasePtr &arg) -> TypePtr {
                            MS_EXCEPTION_IF_NULL(arg);
                            return arg->BuildType();
                          });
@@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase {
   }
   // Generate a Graph for this type signature.
-  virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) {
+  virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
     MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
   }
@@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase {
   std::string ToString() const override { return name_; }
   std::size_t hash() const override { return tid(); }
-  virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; }
-  bool operator==(const Value& other) const override {
+  virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
+  bool operator==(const Value &other) const override {
     if (other.isa()) {
       return &other == this;
     } else {
diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/ccsrc/ir/meta_tensor.cc
index e9221039a77..5bb9ae3c065 100644
--- a/mindspore/ccsrc/ir/meta_tensor.cc
+++ b/mindspore/ccsrc/ir/meta_tensor.cc
@@ -31,7 +31,7 @@ namespace mindspore {
 namespace tensor {
-void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
+void DataBuf2Contiguous(const py::array &src, py::array *const dest) {
   if (dest == nullptr) {
     MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
   }
@@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
 // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
 MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
-MetaTensor::MetaTensor(const TypeId data_type, const std::vector& shape) : data_type_(data_type), shape_(shape) {}
+MetaTensor::MetaTensor(const TypeId data_type, const std::vector &shape) : data_type_(data_type), shape_(shape) {}
-MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
+MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) {
   TypeId data_type = TypeId::kTypeUnknown;
   if (type_ptr != nullptr) {
     data_type = type_ptr->type_id();
@@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
   }
 }
-MetaTensor::MetaTensor(const MetaTensor& meta_tensor)
+MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
     : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
-MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
+MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
   if (&meta_tensor == this) {
     return *this;
   }
@@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
   return *this;
 }
-bool MetaTensor::operator==(const MetaTensor& meta_tensor) const {
+bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
   return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
 }
@@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
   return type_ptr;
 }
-void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) {
+void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
   DeviceInfo info(format, data_type);
   set_device_info(info);
 }
@@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const {
   return oss.str();
 }
-Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
+Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) {
   TypeId data_type = TypeId::kTypeUnknown;
   if (type_ptr != nullptr) {
     data_type = type_ptr->type_id();
@@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
   init(data_type_, shape_, &data_);
 }
-Tensor::Tensor(TypeId data_type, const std::vector& shape) { init(data_type, shape, &data_); }
+Tensor::Tensor(TypeId data_type, const std::vector &shape) { init(data_type, shape, &data_); }
-Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); }
+Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); }
-Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); }
+Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); }
-Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); }
+Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); }
-Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
+Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
-Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
+Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
-Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type)
+Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
     : MetaTensor(tensor), device_address_(tensor.device_address()) {
   init(tensor.data_, data_type);
} -Tensor& Tensor::operator=(const Tensor& tensor) { +Tensor &Tensor::operator=(const Tensor &tensor) { if (this != &tensor) { MetaTensor::operator=(tensor); dirty_ = tensor.is_dirty(); @@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) { return *this; } -bool Tensor::operator==(const Tensor& tensor) const { +bool Tensor::operator==(const Tensor &tensor) const { return (MetaTensor::operator==(tensor) && data_ == tensor.data_); } -bool Tensor::ValueEqualPy(const py::object& other) const { +bool Tensor::ValueEqualPy(const py::object &other) const { if (!py::isinstance(other)) { MS_LOG(WARNING) << "compare other not a tensor"; return false; @@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const { return ValueEqual(py::cast(other)); } -bool Tensor::ValueEqual(const Tensor& other) const { +bool Tensor::ValueEqual(const Tensor &other) const { auto equal = [&other, this]() -> bool { auto np = py::module::import("numpy"); auto equal = np.attr("equal")(data_, other.data_); @@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast(data_type_); } std::vector Tensor::shape_c(void) const { return shape(); } -void* Tensor::data_c(bool writable) { +void *Tensor::data_c(bool writable) { // operand of bit operation should be unsigned int. unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; bool is_c_contiguous = (flags != 0) ? true : false; @@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) { return data_.request(writable).ptr; } -TypeId Tensor::GetDataType(const py::buffer_info& buf) const { +TypeId Tensor::GetDataType(const py::buffer_info &buf) const { TypeId data_type = TypeId::kTypeUnknown; if (buf.format.compare("e") == 0) { data_type = TypeId::kNumberTypeFloat16; @@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const { return data_type; } -void Tensor::init(const py::array& input, const TypePtr& type_ptr) { +void Tensor::init(const py::array &input, const TypePtr &type_ptr) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) { init(input, data_type); } -void Tensor::init(const py::array& input, const TypeId& data_type) { +void Tensor::init(const py::array &input, const TypeId &data_type) { py::buffer_info buf = input.request(); data_type_ = GetDataType(buf); @@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) { } } -void Tensor::init(TypeId data_type, const std::vector& shape, py::array* const data) { +void Tensor::init(TypeId data_type, const std::vector &shape, py::array *const data) { data_type_ = data_type; shape_ = shape; switch (data_type) { @@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) { return data_type_; } -bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, +bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out, const TypeId out_data_type) { if (out == nullptr) { return false; @@ -458,7 +458,7 @@ py::array Tensor::data_sync() { return data_; } -REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { // dtype should define before Tensor, because Tensor init depend dtype (void)py::class_>(*m, "Tensor") .def(py::init(), py::arg("dtype"), py::arg("shape")) @@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, 
([](const py::module* m) { .def("__repr__", &Tensor::ToStringRepr) .def("__eq__", &Tensor::ValueEqualPy) .def(py::pickle( - [](const Tensor& t) { // __getstate__ + [](const Tensor &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(t.data()); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h index 3e28f29f37c..1f6c866f11c 100644 --- a/mindspore/ccsrc/ir/meta_tensor.h +++ b/mindspore/ccsrc/ir/meta_tensor.h @@ -131,16 +131,16 @@ class MetaTensor : public Value { // information of a Tensor. The following codes will create a 2x3 float // param data_type The data type of the tensor. // param shape The shape of the tensor. - MetaTensor(const TypeId data_type, const std::vector& shape); + MetaTensor(const TypeId data_type, const std::vector &shape); - MetaTensor(const TypePtr& type_ptr, const py::tuple& shape); + MetaTensor(const TypePtr &type_ptr, const py::tuple &shape); // brief Constructs a MetaTensor object from an existing MetaTensor instance. // // The constructed MetaTensor object will have the same data type and shape as the // meta_tensor. // // param meta_tensor An existing MetaTensor object. - MetaTensor(const MetaTensor& meta_tensor); + MetaTensor(const MetaTensor &meta_tensor); ~MetaTensor() override = default; MS_DECLARE_PARENT(MetaTensor, Value) @@ -149,7 +149,7 @@ class MetaTensor : public Value { // The constructed MetaTensor object has the same type and shape with meta_tensor. // // param meta_tensor An existing MetaTensor object. - virtual MetaTensor& operator=(const MetaTensor& meta_tensor); + virtual MetaTensor &operator=(const MetaTensor &meta_tensor); // brief Compares two MetaTensor objects. // @@ -157,7 +157,7 @@ class MetaTensor : public Value { // // param meta_tensor The MetaTensor object to be compared. // return true: If having same type and shape, return true, or return false. - virtual bool operator==(const MetaTensor& meta_tensor) const; + virtual bool operator==(const MetaTensor &meta_tensor) const; // brief Returns the data type of the tensor in its MetaTensor. // @@ -193,7 +193,7 @@ class MetaTensor : public Value { // // param shape The shape of the tensor. // return The shape's size. - size_t set_shape(const std::vector& shape) { + size_t set_shape(const std::vector &shape) { this->shape_ = shape; return shape_.size(); } @@ -202,9 +202,9 @@ class MetaTensor : public Value { DeviceInfo device_info() const { return device_info_; } // Set tensor's device info. - void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; } + void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } - void SetDeviceInfo(const std::string& format, const TypePtr& data_type); + void SetDeviceInfo(const std::string &format, const TypePtr &data_type); // Get the size of a given dimension by its index number. int DimensionSize(size_t index) const; @@ -222,9 +222,9 @@ class MetaTensor : public Value { } return hash_value; } - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -262,49 +262,49 @@ class Tensor : public MetaTensor { // // param type_ptr [TypePty] Data type of the tensor. 
// param py_shape [py::tuple] The shape represented by py::tuple of the tensor. - Tensor(const TypePtr& type_ptr, const py::tuple& shape); + Tensor(const TypePtr &type_ptr, const py::tuple &shape); // brief Constructor for C++. // // param data_type [TypeId] Data type of the tensor. // param shape The shape represented by std::vector of the tensor. - Tensor(TypeId data_type, const std::vector& shape); + Tensor(TypeId data_type, const std::vector &shape); // brief Constructor for Python. // // param input [py::array] Data value of the tensor. // param data_type [TypeId] Data type of the tensor. - explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::list] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::tuple] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::float_] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::int_] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [Tensor] the data for tensor // param data_type [TypeId] data type - Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr); + Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr); ~Tensor() override = default; @@ -315,7 +315,7 @@ class Tensor : public MetaTensor { // The constructed Tensor object has the same type and shape with tensor. // // param tensor An existing Tensor object. - Tensor& operator=(const Tensor& tensor); + Tensor &operator=(const Tensor &tensor); // brief Compares two Tensor objects. // @@ -324,17 +324,17 @@ class Tensor : public MetaTensor { // // param tensor The Tensor object to be compared. // return true: If having same type, shape and data, return true, or return false. - bool operator==(const Tensor& tensor) const; + bool operator==(const Tensor &tensor) const; // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. - bool ValueEqual(const Tensor& other) const; + bool ValueEqual(const Tensor &other) const; // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. 
- bool ValueEqualPy(const py::object& other) const; + bool ValueEqualPy(const py::object &other) const; - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -375,13 +375,13 @@ class Tensor : public MetaTensor { // // param writable true if writable, false if read only // return The pointer to the object - void* data_c(bool writable = false); + void *data_c(bool writable = false); // brief Get data type from tensor data. // // param buf The buffer info of the py::array data. // return The [TypeId] of the tensor data. - TypeId GetDataType(const py::buffer_info& buf) const; + TypeId GetDataType(const py::buffer_info &buf) const; // brief Sets the data type of a tensor. // @@ -401,23 +401,23 @@ class Tensor : public MetaTensor { // param input [py::array] the data for tensor // param data_type [TypeId] data type // return true if succeed, false if failed. - void init(const py::array& input, const TypeId& data_type); - void init(const py::array& input, const TypePtr& type_ptr); + void init(const py::array &input, const TypeId &data_type); + void init(const py::array &input, const TypePtr &type_ptr); // brief init tensor attribute // // param data_type [TypeId] Data type of the tensor. // param shape [py::array] The shape of the tensor. // return true if succeed, false if failed. - void init(TypeId data_type, const std::vector& shape, py::array* data); + void init(TypeId data_type, const std::vector &shape, py::array *data); - bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type); + bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type); public: bool is_dirty() const { return dirty_; } void set_dirty(const bool dirty) { dirty_ = dirty; } DeviceAddressPtr device_address() const { return device_address_; } - void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; } + void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } py::array data_sync(); private: diff --git a/mindspore/ccsrc/ir/named.cc b/mindspore/ccsrc/ir/named.cc index 3d12e8a4536..67e11c64d33 100644 --- a/mindspore/ccsrc/ir/named.cc +++ b/mindspore/ccsrc/ir/named.cc @@ -18,9 +18,9 @@ #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { -bool Named::operator==(const Value& other) const { +bool Named::operator==(const Value &other) const { if (other.isa()) { - auto other_named = static_cast(other); + auto other_named = static_cast(other); return *this == other_named; } else { return false; diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/ccsrc/ir/named.h index 0651307a91b..76136fb2987 100644 --- a/mindspore/ccsrc/ir/named.h +++ b/mindspore/ccsrc/ir/named.h @@ -27,18 +27,18 @@ namespace mindspore { class Named : public Value { public: - explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash{}(name); } - Named(const Named& other) : Value(other) { + explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash{}(name); } + Named(const Named &other) : Value(other) { this->name_ = other.name_; hash_id_ = std::hash{}(other.name_); } ~Named() override = default; MS_DECLARE_PARENT(Named, Value); - const std::string& name() const { return name_; } - virtual bool operator==(const Named& other) 
const { return name_ == other.name(); } - bool operator==(const Value& other) const override; - Named& operator=(const Named& other) { + const std::string &name() const { return name_; } + virtual bool operator==(const Named &other) const { return name_ == other.name(); } + bool operator==(const Value &other) const override; + Named &operator=(const Named &other) { if (&other != this) { this->type_ = other.type_; this->name_ = other.name_; @@ -50,7 +50,7 @@ class Named : public Value { std::size_t Hash() const { return hash_id_; } std::size_t hash() const override { return hash_id_; } - friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { + friend std::ostream &operator<<(std::ostream &os, const Named &nmd) { os << nmd.name(); return os; } diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index a576c1e76b8..d40f8a265d3 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -31,7 +31,7 @@ namespace mindspore { using mindspore::abstract::AbstractFunction; -abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) { +abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { auto prim_func = std::make_shared(shared_from_base(), anf_node); return prim_func; } @@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() { return fn; } -bool Primitive::operator==(const Value& other) const { +bool Primitive::operator==(const Value &other) const { if (other.isa()) { - auto other_prim = static_cast(other); + auto other_prim = static_cast(other); return *this == other_prim; } else { return false; } } -bool Primitive::operator==(const Primitive& other) const { +bool Primitive::operator==(const Primitive &other) const { if (name() != other.name()) { return false; } if (attrs_.size() != other.attrs_.size()) { return false; } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair& item) -> bool { + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { if (item.second == nullptr) { return false; } @@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const { void Primitive::set_signatures( std::vector> signatures) { signatures_.clear(); - for (auto& signature : signatures) { + for (auto &signature : signatures) { std::string name; SignatureEnumRW rw; SignatureEnumKind kind; @@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const { std::ostringstream oss; oss << "["; bool is_first = true; - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { if (is_first) { is_first = false; } else { @@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const { } py::function PrimitivePy::GetBpropFunction() { - static const char* const get_bprop_func_name = "get_bprop"; + static const char *const get_bprop_func_name = "get_bprop"; if (py::hasattr(python_obj_, get_bprop_func_name)) { py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); return fn; @@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() { } py::function PrimitivePy::GetComputeFunction() { - static const char* const compute_func_name = "vm_impl"; + static const char *const compute_func_name = "vm_impl"; if (py::hasattr(python_obj_, compute_func_name)) { MS_LOG(INFO) << "" << name() << " compute_func_name"; @@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() { return vm_fn; } -void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { +void PrimitivePy::AddPyAttr(const py::str &name, 
const py::object &obj) { std::string attr_name = name; ValuePtr converted_ret = nullptr; if (py::isinstance(obj)) { @@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { py::dict PrimitivePy::GetAttrDict() { py::dict attr_dict; - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); } return attr_dict; } -REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) .value("builtin", PrimType::kPrimTypeBuiltIn) @@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { .value("user_custom", PrimType::kPrimTypeUserCustom); (void)py::class_>(*m, "Primitive_") .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) - .def(py::init()) + .def(py::init()) .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index 7dd37eb15ff..d16a524f692 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -48,25 +48,25 @@ enum PrimType { class Primitive : public Named { public: - explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn) + explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) : Named(name), signatures_(), prim_type_(prim_type) {} - Primitive(const Primitive& prim) + Primitive(const Primitive &prim) : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} MS_DECLARE_PARENT(Primitive, Named); - abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node); + abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); std::string ToString() const override { return name(); } virtual py::function GetBpropFunction(); virtual py::function GetComputeFunction(); - Primitive& AddAttr(const std::string& name, const ValuePtr& attr) { + Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; return *this; } - Primitive& SetAttrs(const std::unordered_map& attrs) { - for (auto& attr : attrs) { + Primitive &SetAttrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { attrs_[attr.first] = attr.second; } return *this; @@ -76,21 +76,21 @@ class Primitive : public Named { std::vector> signatures); - const std::vector& signatures() const { return signatures_; } + const std::vector &signatures() const { return signatures_; } - void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; } - void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); } + void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } + void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - ValuePtr GetAttr(const std::string& attrName) const { + ValuePtr GetAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); return iter == attrs_.cend() ? nullptr : iter->second; } - const std::unordered_map& attrs() const { return attrs_; } + const std::unordered_map &attrs() const { return attrs_; } // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. 
bool HasAttr() const { return !attrs_.empty(); } - bool HasAttr(const std::string& attrName) const { + bool HasAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); return !(iter == attrs_.cend()); } @@ -103,8 +103,8 @@ class Primitive : public Named { PrimType prim_type() const { return prim_type_; } std::string instance_name() const { return instance_name_; } std::string GetAttrsText() const; - bool operator==(const Value& other) const override; - bool operator==(const Primitive& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Primitive &other) const; ~Primitive() override = default; protected: @@ -118,18 +118,18 @@ class Primitive : public Named { class PrimitivePy : public Primitive { public: - PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {} + PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} ~PrimitivePy() override = default; MS_DECLARE_PARENT(PrimitivePy, Primitive); py::function GetBpropFunction() override; py::function GetComputeFunction() override; - void AddPyAttr(const py::str& name, const py::object& obj); + void AddPyAttr(const py::str &name, const py::object &obj); py::dict GetAttrDict(); const bool parse_info_ = true; - const py::object& GetPyObj() const { return python_obj_; } + const py::object &GetPyObj() const { return python_obj_; } bool is_tuple_input_ = false; private: @@ -138,13 +138,13 @@ class PrimitivePy : public Primitive { using PrimitivePyPtr = std::shared_ptr; -inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) { +inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { os << *p; return os; } struct PrimitiveEqual { - bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); return t1->name() == t2->name(); @@ -152,7 +152,7 @@ struct PrimitiveEqual { }; struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const& prim) const { + std::size_t operator()(PrimitivePtr const &prim) const { std::size_t hash = std::hash()(prim->name()); return hash; } diff --git a/mindspore/ccsrc/ir/scalar.h b/mindspore/ccsrc/ir/scalar.h index 3e0a827b072..ab6c485540a 100644 --- a/mindspore/ccsrc/ir/scalar.h +++ b/mindspore/ccsrc/ir/scalar.h @@ -55,8 +55,8 @@ class BoolImm : public Scalar { bool value() const { return v_; } bool IsZero() override { return v_ == false; } bool IsOne() override { return v_ == true; } - bool operator==(const Value& other) const override; - bool operator==(const BoolImm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const BoolImm &other) const; std::string ToString() const override { if (v_) { return "true"; @@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool) class IntergerImm : public Scalar { public: IntergerImm() = default; - explicit IntergerImm(const TypePtr& t) : Scalar(t) {} + explicit IntergerImm(const TypePtr &t) : Scalar(t) {} ~IntergerImm() override = default; MS_DECLARE_PARENT(IntergerImm, Scalar) }; @@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int8_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int8Imm& other) const; + bool operator==(const Value &other) const override; + bool 
operator==(const Int8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int16_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int16Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int32_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int64_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint8_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt8Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint16_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt16Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint32_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint64_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt64Imm &other) const; std::string ToString() const override { return 
std::to_string(v_); } std::string DumpText() const override { @@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t); class FloatImm : public Scalar { public: FloatImm() = default; - explicit FloatImm(const TypePtr& t) : Scalar(t) {} + explicit FloatImm(const TypePtr &t) : Scalar(t) {} ~FloatImm() override = default; MS_DECLARE_PARENT(FloatImm, Scalar) }; @@ -312,8 +312,8 @@ class FP32Imm : public FloatImm { bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } float value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const FP32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const FP32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -338,8 +338,8 @@ class FP64Imm : public FloatImm { bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } double value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const FP64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const FP64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { diff --git a/mindspore/ccsrc/ir/signature.cc b/mindspore/ccsrc/ir/signature.cc index b7eec921d45..8f312d5b981 100644 --- a/mindspore/ccsrc/ir/signature.cc +++ b/mindspore/ccsrc/ir/signature.cc @@ -21,8 +21,8 @@ #include "pipeline/parse/data_converter.h" namespace mindspore { -Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, - const py::object& arg_default, const SignatureEnumDType& arg_dtype) +Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const py::object &arg_default, const SignatureEnumDType &arg_dtype) : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { if (py::isinstance(arg_default) && py::cast(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { @@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, } } -Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) +Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) : name(arg_name), rw(rw_tag), kind(arg_kind), default_value(nullptr), dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} -REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { (void)py::enum_(*m, "signature_rw", py::arithmetic()) .value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_WRITE", SignatureEnumRW::kRWWrite) diff --git a/mindspore/ccsrc/ir/signature.h b/mindspore/ccsrc/ir/signature.h index 8e7409ab264..48be7e0f315 100644 --- a/mindspore/ccsrc/ir/signature.h +++ b/mindspore/ccsrc/ir/signature.h @@ -61,9 +61,9 @@ struct Signature { SignatureEnumKind kind; ValuePtr default_value; // nullptr for no default value SignatureEnumDType dtype; - Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, - const py::object& arg_default, const SignatureEnumDType& arg_dtype); - Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, 
const SignatureEnumKind& arg_kind); + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const py::object &arg_default, const SignatureEnumDType &arg_dtype); + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value.cc b/mindspore/ccsrc/ir/value.cc index f9e8abaee9d..e386e1ffd2b 100644 --- a/mindspore/ccsrc/ir/value.cc +++ b/mindspore/ccsrc/ir/value.cc @@ -24,7 +24,7 @@ #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { -const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const { +const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; } @@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) { } } -bool BoolImm::operator==(const Value& other) const { +bool BoolImm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; } +bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; } -bool Int8Imm::operator==(const Value& other) const { +bool Int8Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; } -bool Int16Imm::operator==(const Value& other) const { +bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; } +bool Int16Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; } -bool Int32Imm::operator==(const Value& other) const { +bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; } +bool Int32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; } -bool Int64Imm::operator==(const Value& other) const { +bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; } +bool Int64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; } -bool UInt8Imm::operator==(const Value& other) const { +bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; } +bool UInt8Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; } -bool UInt16Imm::operator==(const Value& other) const { +bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; } +bool UInt16Imm::operator==(const Value &other) const { if 
(other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; } -bool UInt32Imm::operator==(const Value& other) const { +bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; } +bool UInt32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; } -bool UInt64Imm::operator==(const Value& other) const { +bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; } +bool UInt64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; } -bool FP32Imm::operator==(const Value& other) const { +bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; } +bool FP32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } -bool FP64Imm::operator==(const Value& other) const { +bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } +bool FP64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSequeue::operator==(const Value& other) const { +bool ValueSequeue::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSequeue::operator==(const ValueSequeue& other) const { +bool ValueSequeue::operator==(const ValueSequeue &other) const { if (other.elements_.size() != elements_.size()) { return false; } return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), - [](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; }); + [](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; }); } std::string ValueSequeue::ToString() const { std::ostringstream buffer; bool begin = true; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const { return oss.str(); } -bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } -bool StringImm::operator==(const Value& other) const { +bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } +bool StringImm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; } +bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } -bool RefKey::operator==(const Value& other) const { +bool RefKey::operator==(const Value 
&other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; } +bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } -bool AnyValue::operator==(const Value& other) const { +bool AnyValue::operator==(const Value &other) const { if (other.isa()) { return true; } else { @@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_sharedToAbstract(); }); @@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() { abstract::AbstractBasePtr ValueList::ToAbstract() { abstract::AbstractBasePtrList a_list; - (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { + (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { MS_EXCEPTION_IF_NULL(ele); return ele->ToAbstract(); }); @@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const { return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); } -bool ValueSlice::operator==(const Value& other) const { +bool ValueSlice::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSlice::operator==(const ValueSlice& other) const { +bool ValueSlice::operator==(const ValueSlice &other) const { MS_EXCEPTION_IF_NULL(start_); MS_EXCEPTION_IF_NULL(stop_); MS_EXCEPTION_IF_NULL(step_); @@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const { return hash_combine({tid(), std::hash{}(key_), value_->hash()}); } -bool KeywordArg::operator==(const Value& other) const { +bool KeywordArg::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); } +bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); } std::string KeywordArg::ToString() const { std::ostringstream buffer; @@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() { return std::make_shared(key_, argument); } -const ValuePtr ValueDictionary::operator[](const std::string& key) const { +const ValuePtr ValueDictionary::operator[](const std::string &key) const { auto it = std::find_if(key_values_.begin(), key_values_.end(), - [key](const std::pair& item) { return item.first == key; }); + [key](const std::pair &item) { return item.first == key; }); if (it == key_values_.end()) { MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; } return it->second; } -bool ValueDictionary::operator==(const Value& other) const { +bool ValueDictionary::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueDictionary::operator==(const ValueDictionary& other) const { +bool ValueDictionary::operator==(const ValueDictionary &other) const { if (key_values_.size() != other.key_values_.size()) { return false; } @@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { std::vector> kv; (void)std::transform( key_values_.begin(), key_values_.end(), 
std::back_inserter(kv), - [](const std::pair& item) { return std::make_pair(item.first, item.second->ToAbstract()); }); + [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); return std::make_shared(kv); } REGISTER_PYBIND_DEFINE( - RefKey, ([](const py::module* m) { + RefKey, ([](const py::module *m) { (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); })); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/ccsrc/ir/value.h index 85f514b57b3..c80e22f735f 100644 --- a/mindspore/ccsrc/ir/value.h +++ b/mindspore/ccsrc/ir/value.h @@ -35,19 +35,19 @@ namespace mindspore { class ValueSequeue : public Value { public: - explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) { + explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { TypePtrList t_list; - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) { + (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { MS_EXCEPTION_IF_NULL(ele); return ele->type(); }); TypePtr t = std::make_shared(t_list); type_ = t; } - ValueSequeue(const std::initializer_list& elements) : elements_(elements.begin(), elements.end()) { + ValueSequeue(const std::initializer_list &elements) : elements_(elements.begin(), elements.end()) { TypePtrList t_list; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), - [](const ValuePtr& ele) { return ele->type(); }); + [](const ValuePtr &ele) { return ele->type(); }); TypePtr t = std::make_shared(t_list); type_ = t; } @@ -56,10 +56,10 @@ class ValueSequeue : public Value { std::size_t hash() const override { return hash_combine(tid(), std::hash{}(elements_.size())); } std::size_t size() const { return elements_.size(); } bool erase(size_t idx); - const ValuePtr operator[](const std::size_t& dim) const; - const ValuePtrList& value() const { return elements_; } - bool operator==(const Value& other) const override; - bool operator==(const ValueSequeue& other) const; + const ValuePtr operator[](const std::size_t &dim) const; + const ValuePtrList &value() const { return elements_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueSequeue &other) const; std::string ToString() const override; std::string DumpText() const override; @@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr; class ValueTuple : public ValueSequeue { public: - explicit ValueTuple(const std::vector& elements) : ValueSequeue(elements) {} - ValueTuple(const std::initializer_list& elements) : ValueSequeue(elements) {} + explicit ValueTuple(const std::vector &elements) : ValueSequeue(elements) {} + ValueTuple(const std::initializer_list &elements) : ValueSequeue(elements) {} ~ValueTuple() override = default; MS_DECLARE_PARENT(ValueTuple, ValueSequeue) abstract::AbstractBasePtr ToAbstract() override; @@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr; class ValueList : public ValueSequeue { public: - explicit ValueList(const std::vector& elements) : ValueSequeue(elements) {} - ValueList(const std::initializer_list& elements) : ValueSequeue(elements) {} + explicit ValueList(const std::vector &elements) : ValueSequeue(elements) {} + ValueList(const std::initializer_list &elements) : ValueSequeue(elements) {} ~ValueList() override = default; MS_DECLARE_PARENT(ValueList, ValueSequeue) abstract::AbstractBasePtr ToAbstract() override; @@ -94,7 +94,7 @@ class ValueList : 
public ValueSequeue { }; using ValueListPtr = std::shared_ptr; -inline ValuePtr MakeValue(const std::vector& v) { return std::make_shared(v); } +inline ValuePtr MakeValue(const std::vector &v) { return std::make_shared(v); } inline ValuePtr MakeValue(std::initializer_list v) { return std::make_shared(v); } template @@ -103,7 +103,7 @@ template struct is_vector> : public std::true_type {}; template ::value, typename T::value_type>::type> -ValuePtr MakeValue(const T& vec) { +ValuePtr MakeValue(const T &vec) { std::vector list; (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); return std::make_shared(list); @@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) { class ValueSlice : public Value { public: - ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step) + ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) : start_(start), stop_(stop), step_(step) {} ~ValueSlice() override = default; MS_DECLARE_PARENT(ValueSlice, Value) std::size_t hash() const override; - bool operator==(const Value& other) const override; - bool operator==(const ValueSlice& other) const; + bool operator==(const Value &other) const override; + bool operator==(const ValueSlice &other) const; std::string ToString() const override; @@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr; class KeywordArg : public Value { public: - KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {} + KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} ~KeywordArg() override = default; MS_DECLARE_PARENT(KeywordArg, Value) std::size_t hash() const override; ValuePtr get_value() const { return value_; } - bool operator==(const Value& other) const override; - bool operator==(const KeywordArg& other) const; + bool operator==(const Value &other) const override; + bool operator==(const KeywordArg &other) const; std::string ToString() const override; @@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr; class ValueDictionary : public Value { public: - explicit ValueDictionary(const std::vector>& key_values) : key_values_(key_values) {} + explicit ValueDictionary(const std::vector> &key_values) : key_values_(key_values) {} ~ValueDictionary() override = default; MS_DECLARE_PARENT(ValueDictionary, Value) std::size_t hash() const override { return hash_combine(tid(), std::hash{}(key_values_.size())); } std::size_t size() const { return key_values_.size(); } - const ValuePtr operator[](const std::string& key) const; - const std::vector>& value() const { return key_values_; } - bool operator==(const Value& other) const override; - bool operator==(const ValueDictionary& other) const; + const ValuePtr operator[](const std::string &key) const; + const std::vector> &value() const { return key_values_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueDictionary &other) const; std::string ToString() const override { std::ostringstream buffer; std::vector keys; std::vector values; - for (const auto& kv : key_values_) { + for (const auto &kv : key_values_) { keys.push_back(kv.first); values.push_back(kv.second); } buffer << "(Dict: " << " keys:("; - for (const auto& key : keys) { + for (const auto &key : keys) { buffer << key << ", "; } buffer << ") values:("; - for (const auto& value : values) { + for (const auto &value : values) { MS_EXCEPTION_IF_NULL(value); buffer << value->DumpText() << ", "; } @@ -195,14 +195,14 @@ using 
ValueDictionaryPtr = std::shared_ptr; class StringImm : public Value { public: - explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} + explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} ~StringImm() override = default; MS_DECLARE_PARENT(StringImm, Value) std::size_t hash() const override { return hash_; } - const std::string& value() const { return str_; } - bool operator==(const Value& other) const override; - bool operator==(const StringImm& other) const; + const std::string &value() const { return str_; } + bool operator==(const Value &other) const override; + bool operator==(const StringImm &other) const; abstract::AbstractBasePtr ToAbstract() override; std::string ToString() const override { return str_; } @@ -218,18 +218,18 @@ class StringImm : public Value { }; using StringImmPtr = std::shared_ptr; IMM_TRAITS(StringImmPtr, std::string) -IMM_TRAITS(StringImmPtr, const char*) +IMM_TRAITS(StringImmPtr, const char *) class RefKey : public Value { public: - explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} + explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} ~RefKey() override = default; MS_DECLARE_PARENT(RefKey, Value) std::size_t hash() const override { return hash_; } - const std::string& tag() const { return tag_; } - bool operator==(const Value& other) const override; - bool operator==(const RefKey& other) const; + const std::string &tag() const { return tag_; } + bool operator==(const Value &other) const override; + bool operator==(const RefKey &other) const; abstract::AbstractBasePtr ToAbstract() override; std::string ToString() const override { return "RefKey[" + tag_ + "]"; } @@ -251,13 +251,13 @@ class AnyValue : public Value { ~AnyValue() override = default; MS_DECLARE_PARENT(AnyValue, Value) std::size_t hash() const override { return tid(); } - bool operator==(const Value& other) const override; + bool operator==(const Value &other) const override; abstract::AbstractBasePtr ToAbstract() override; }; extern const ValuePtr kAnyValue; template <> -inline const char* GetValue(const ValuePtr& value) { +inline const char *GetValue(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Value is nullptr"; } @@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) { template ::type, typename U = typename std::enable_if::value, typename S::value_type>::type> -std::vector GetValue(const ValuePtr& value) { +std::vector GetValue(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Value is nullptr"; } @@ -280,21 +280,21 @@ std::vector GetValue(const ValuePtr& value) { << ">"; } std::vector rets; - const std::vector& vals = value->cast()->value(); + const std::vector &vals = value->cast()->value(); (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), - [](const ValuePtr& v) { return GetValue(v); }); + [](const ValuePtr &v) { return GetValue(v); }); return rets; } -inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared(t); } +inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared(t); } template ::value>::type> -inline ValueNodePtr NewValueNode(const std::shared_ptr& x) { +inline ValueNodePtr NewValueNode(const std::shared_ptr &x) { return NewValueNode(MakeValue(x)); } template ::value>::type> -inline ValueNodePtr NewValueNode(const T& x) { +inline ValueNodePtr NewValueNode(const T &x) { 
   return NewValueNode(MakeValue(x));
 }
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/visitor.h b/mindspore/ccsrc/ir/visitor.h
index 5305d1fe858..e771f7ad28b 100644
--- a/mindspore/ccsrc/ir/visitor.h
+++ b/mindspore/ccsrc/ir/visitor.h
@@ -22,15 +22,15 @@
 #include "optimizer/opt.h"
 namespace mindspore {
-using VisitFuncType = std::function;
+using VisitFuncType = std::function;
 class AnfVisitor {
  public:
-  virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&);
-  virtual void Visit(const AnfNodePtr&);
-  virtual void Visit(const CNodePtr&);
-  virtual void Visit(const ValueNodePtr&);
-  virtual void Visit(const ParameterPtr&);
-  VisitFuncType Match(const PrimitivePtr&, const std::vector& = {});
+  virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &);
+  virtual void Visit(const AnfNodePtr &);
+  virtual void Visit(const CNodePtr &);
+  virtual void Visit(const ValueNodePtr &);
+  virtual void Visit(const ParameterPtr &);
+  VisitFuncType Match(const PrimitivePtr &, const std::vector & = {});
   virtual ~AnfVisitor() = default;
 };
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc
index 7934bd0a5c4..3d3282e7b55 100755
--- a/mindspore/ccsrc/kernel/kernel_query.cc
+++ b/mindspore/ccsrc/kernel/kernel_query.cc
@@ -26,12 +26,12 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
-                             std::vector>* kernel_info_list) {
+void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
+                             std::vector> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   std::vector> filtered_list;
   (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
-                     [&](const std::shared_ptr& kernel_build_info) {
+                     [&](const std::shared_ptr &kernel_build_info) {
                        return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
                               AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
                      });
@@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
   }
 }
 }  // namespace
-void KernelQuery(const CNodePtr& kernel_node, std::vector>* kernel_info_list) {
+void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   TbeMetadataInfo(kernel_node, kernel_info_list);
diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h
index 215df217766..670830a8b18 100644
--- a/mindspore/ccsrc/kernel/oplib/opinfo.h
+++ b/mindspore/ccsrc/kernel/oplib/opinfo.h
@@ -38,11 +38,11 @@ class OpAttr {
   std::string value() const { return value_; }
   std::string default_value() const { return default_value_; }
-  void set_name(const std::string& name) { name_ = name; }
-  void set_param_type(const std::string& param_type) { param_type_ = param_type; }
-  void set_type(const std::string& type) { type_ = type; }
-  void set_value(const std::string& value) { value_ = value; }
-  void set_default_value(const std::string& default_value) { default_value_ = default_value; }
+  void set_name(const std::string &name) { name_ = name; }
+  void set_param_type(const std::string &param_type) { param_type_ = param_type; }
+  void set_type(const std::string &type) { type_ = type; }
+  void set_value(const std::string &value) { value_ = value; }
+  void set_default_value(const std::string &default_value) { default_value_ = default_value; }

  private:
   std::string name_;
@@ -67,13 +67,13 @@ class OpIOInfo {
std::vector formats() const { return formats_; } void set_index(const int index) { index_ = index; } - void set_name(const std::string& name) { name_ = name; } + void set_name(const std::string &name) { name_ = name; } void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } - void set_param_type(const std::string& param_type) { param_type_ = param_type; } - void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } - void set_shape(const std::string& shape) { shape_ = shape; } - void set_dtypes(const std::vector& dtype) { dtypes_ = dtype; } - void set_formats(const std::vector& formats) { formats_ = formats; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } + void set_shape(const std::string &shape) { shape_ = shape; } + void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } + void set_formats(const std::vector &formats) { formats_ = formats; } private: int index_ = 0; @@ -104,24 +104,24 @@ class OpInfo { std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } - const std::unordered_map& ref_infos() const { return ref_infos_; } + const std::unordered_map &ref_infos() const { return ref_infos_; } - void set_op_name(const std::string& op_name) { op_name_ = op_name; } + void set_op_name(const std::string &op_name) { op_name_ = op_name; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } - void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; } - void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; } + void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } + void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } - void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; } + void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } - void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } + void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } - void add_attrs_ptr(const std::shared_ptr& attr) { attrs_ptr_.push_back(attr); } - void add_inputs_ptr(const std::shared_ptr& input) { inputs_ptr_.push_back(input); } - void add_outputs_ptr(const std::shared_ptr& output) { outputs_ptr_.push_back(output); } - void set_inputs_ptr(const std::vector>& inputs) { inputs_ptr_ = inputs; } - void set_outputs_ptr(const std::vector>& outputs) { outputs_ptr_ = outputs; } + void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } + void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } + void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } + void set_inputs_ptr(const std::vector> &inputs) { inputs_ptr_ = inputs; } + void set_outputs_ptr(const std::vector> &outputs) { outputs_ptr_ = outputs; } 
bool is_ref() const { return !ref_infos_.empty(); } bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index c8cc1530ce3..cd0f8438676 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) { return "unknow"; } } -bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) { +bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { bool ret = false; try { auto op_json = nlohmann::json::parse(json_string); @@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) if (!ret) { MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); } return ret; } -void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info) { +void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { op_info->set_async_flag(obj.at(kAsyncFlag)); op_info->set_binfile_name(obj.at(kBinfileName)); op_info->set_compute_cost(obj.at(kComputeCost)); @@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p } } -bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, - const std::string& impl_path) { +bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, + const std::string &impl_path) { std::shared_ptr op_info = std::make_shared(); MS_EXCEPTION_IF_NULL(op_info); op_info->set_op_name(obj.at(kOpName)); @@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI DecodeTBESpecificInfo(obj, op_info); } auto attrs = obj.at(kAttr); - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (!DecodeAttr(attr, imply_type, op_info)) { MS_LOG(DEBUG) << "DecodeAttr Failed"; return false; @@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI dtype_format = obj.at(kDtypeFormat); } auto inputs = obj.at(kIputs); - for (const auto& input : inputs) { + for (const auto &input : inputs) { if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { MS_LOG(DEBUG) << "DecodeInputOutput Failed"; return false; } } auto outputs = obj.at(kOutputs); - for (const auto& output : outputs) { + for (const auto &output : outputs) { if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { MS_LOG(DEBUG) << "DecodeInputOutput Failed"; return false; @@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI return true; } -bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, - const std::shared_ptr& op_info) { +bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); bool ret = true; try { @@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, op_attr->set_default_value(obj.at(kDefaultValue)); } op_info->add_attrs_ptr(op_attr); - 
} catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); ret = false; } return ret; } -bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, +bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index) { bool ret = true; try { std::vector dtype; std::vector format; - for (const auto& it : dtype_format) { + for (const auto &it : dtype_format) { dtype.emplace_back(it[index][0]); format.emplace_back(it[index][1]); } op_io->set_dtypes(dtype); op_io->set_formats(format); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); ret = false; } return ret; } -bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info, const nlohmann::json& dtype_format) { +bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { bool ret = true; try { std::shared_ptr op_io = std::make_shared(); @@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply } else if (io_type == kOutput) { op_info->add_outputs_ptr(op_io); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); ret = false; } return ret; } -std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) { +std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool is_gpu = (context->device_target() == kGPUDevice); @@ -260,7 +260,7 @@ std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType im << ", current op num:" << op_info_.size(); return nullptr; } - for (const auto& op_info : op_info_) { + for (const auto &op_info : op_info_) { MS_EXCEPTION_IF_NULL(op_info); if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { return op_info; @@ -271,14 +271,14 @@ std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType im return nullptr; } -bool OpLib::GetRefInfo(const std::shared_ptr& op_info) { +bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - const auto& output_infos = op_info->outputs_ptr(); - const auto& input_infos = op_info->inputs_ptr(); + const auto &output_infos = op_info->outputs_ptr(); + const auto &input_infos = op_info->inputs_ptr(); for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { - const auto& out_name = output_infos[out_index]->name(); + const auto &out_name = output_infos[out_index]->name(); for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { - const auto& in_name = input_infos[in_index]->name(); + const auto &in_name = input_infos[in_index]->name(); if (out_name == in_name) { if (op_info->has_ref_index(out_index)) { MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; @@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr& op_info) { return true; } -bool OpLib::CheckRepetition(const std::shared_ptr& op_info) { +bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - for (const auto& exist_op_info : op_info_) { + for (const auto &exist_op_info : op_info_) { 
MS_EXCEPTION_IF_NULL(exist_op_info); if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && exist_op_info->impl_path() != op_info->impl_path()) { diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h index 0e11e28d580..3d4dcad908e 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ b/mindspore/ccsrc/kernel/oplib/oplib.h @@ -28,23 +28,23 @@ class OpLib { public: OpLib() = default; virtual ~OpLib() = default; - bool RegOp(const std::string& json_string, const std::string& impl_path); - static std::shared_ptr FindOp(const std::string& op_name, OpImplyType imply_type); + bool RegOp(const std::string &json_string, const std::string &impl_path); + static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); protected: static std::vector> op_info_; private: - static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); - static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, - const std::shared_ptr& op_info); - static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, + static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); + static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info); + static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index); - static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info); - static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info, const nlohmann::json& dtype_format); - static bool GetRefInfo(const std::shared_ptr& op_info); - static bool CheckRepetition(const std::shared_ptr& op_info); + static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); + static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format); + static bool GetRefInfo(const std::shared_ptr &op_info); + static bool CheckRepetition(const std::shared_ptr &op_info); }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/mindspore.cc b/mindspore/ccsrc/mindspore.cc index 542814016f5..c98f67b51e5 100644 --- a/mindspore/ccsrc/mindspore.cc +++ b/mindspore/ccsrc/mindspore.cc @@ -19,6 +19,6 @@ namespace mindspore { // cppcheck-suppress unusedFunction -std::string set_version(const std::string& version) { return version; } +std::string set_version(const std::string &version) { return version; } } // namespace mindspore diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc index 80661a45393..772986d7141 100644 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/onnx/onnx_exporter.cc @@ -42,11 +42,11 @@ struct OpMergedInfo { }; using GenAttrFuncType = - std::function; + std::function; template -void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { auto casted_value = dyn_cast(value); if (casted_value == nullptr) { 
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; @@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy } template -void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { auto tuple_ptr = dyn_cast(value); if (tuple_ptr == nullptr) { MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; @@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib attr_proto->set_type(attr_type); } -void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); auto attr_value = GetValue(value); if (attr_value == "VALID") { @@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType class OpAttrInfo { public: - OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) + OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) : attr_name_(attr_name), onnx_attr_name_(onnx_attr_name), onnx_attr_type_(onnx_attr_type), fn_gen_attr_(fn_gen_attr) {} ~OpAttrInfo() {} - const std::string& attr_name() const { return attr_name_; } - const std::string& onnx_attr_name() const { return onnx_attr_name_; } + const std::string &attr_name() const { return attr_name_; } + const std::string &onnx_attr_name() const { return onnx_attr_name_; } onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } @@ -134,27 +134,27 @@ class OpAttrInfo { class OpNameInfo { public: - OpNameInfo& set_op_type(const std::string& op_type) { + OpNameInfo &set_op_type(const std::string &op_type) { op_type_ = op_type; return *this; } - const std::string& op_type() const { return op_type_; } + const std::string &op_type() const { return op_type_; } - OpNameInfo& set_onnx_type(const std::string& onnx_type) { + OpNameInfo &set_onnx_type(const std::string &onnx_type) { onnx_type_ = onnx_type; return *this; } - const std::string& onnx_type() const { return onnx_type_; } + const std::string &onnx_type() const { return onnx_type_; } - OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) { + OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); return *this; } - const std::vector& op_attrs() const { return op_attrs_; } + const std::vector &op_attrs() const { return op_attrs_; } private: std::string op_type_; // operator type of MindSpore @@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE( 
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, - [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto, - const PrimitivePtr& prim) { + [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, + const PrimitivePtr &prim) { attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); auto attr_value = GetValue(value); if (attr_value == "valid") { @@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, SetAttrValueToProto) .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, [](ValuePtr, onnx::AttributeProto_AttributeType, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); attr_proto->set_i(0); })) @@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name -void RegisterOpConverters(const std::function& fn) { +void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); fn(OP_CONVERT_FUNCTION_NAME(Mul)()); @@ -265,16 +265,16 @@ class OpConvertRegistry { public: ~OpConvertRegistry() { Clear(); } - static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } + static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } - static OpConvertRegistry& GetSingleton() { + static OpConvertRegistry &GetSingleton() { static OpConvertRegistry registry = OpConvertRegistry(); return registry; } - static const std::unordered_map& GetOpConvertMap() { return GetSingleton().op_map_; } + static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } void Clear() noexcept { op_map_.clear(); } @@ -289,59 +289,59 @@ class OnnxExporter { OnnxExporter() {} ~OnnxExporter() {} - std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); + std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); private: void InitModelInfo(); - void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); - void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); + void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - const PrimitivePtr& prim, const std::vector& inputs, - onnx::GraphProto* graph_proto); + size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *graph_proto); static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false); - void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto); + void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); + void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); - 
void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector& nodes, - std::unordered_map* op_merged_infos_ptr); - void ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); + void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr); + void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); - void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); - void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); + void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); - void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); + void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - std::string GetNodeInputName(const AnfNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto); + void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto); - void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto); - void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto); + void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); + void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); size_t AllocateNodeIndex() { return ++onnx_node_index_; } void ResetNodeIndex() { 
onnx_node_index_ = 0; } - static int GetInt32Value(const AnfNodePtr& node) { + static int GetInt32Value(const AnfNodePtr &node) { auto value_node_ptr = dyn_cast(node); MS_EXCEPTION_IF_NULL(value_node_ptr); return GetValue(value_node_ptr->value()); @@ -352,7 +352,7 @@ class OnnxExporter { size_t onnx_node_index_ = 0; }; -std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { +std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } @@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { OpConvertRegistry::GetSingleton().Clear(); OpConvertRegistry::RegisterAllOpConverters(); InitModelInfo(); - onnx::GraphProto* graph_proto = model_.mutable_graph(); + onnx::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } @@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() { model_.set_ir_version(onnx::IR_VERSION_2019_1_22); model_.set_producer_name("MindSpore"); model_.set_producer_version("1.0"); - onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import(); + onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); opset_proto->set_version(9); } -void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { std::map node_map; onnx_node_index_ = func_graph->parameters().size(); @@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr ExportNodes(func_graph, &node_map, graph_proto); } -void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { - for (auto& param : func_graph->parameters()) { +void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto ¶m : func_graph->parameters()) { const ParameterPtr param_ptr = dyn_cast(param); if (param_ptr == nullptr) { MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; } - onnx::ValueInfoProto* input_proto = graph_proto->add_input(); + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); input_proto->set_name(param_ptr->ToString()); SetValueInfoType(param_ptr, input_proto); @@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP continue; } // parameter with default value is an ONNX initializer - onnx::TensorProto* initializer_proto = graph_proto->add_initializer(); + onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); initializer_proto->set_name(param_ptr->ToString()); SetTensorProtoInfo(param_ptr, initializer_proto); // set value for initializer @@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { return iter->second; } -void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) { +void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { auto dtype = node->Type(); auto shape = node->Shape(); - onnx::TypeProto* type_proto = value_proto->mutable_type(); + onnx::TypeProto *type_proto = value_proto->mutable_type(); if (dtype->isa() && shape->isa()) { auto tensor = dyn_cast(dtype); auto elem_type = tensor->element(); - const auto& dims = dyn_cast(shape)->shape(); + const 
auto &dims = dyn_cast(shape)->shape(); // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); type_proto->mutable_tensor_type()->set_elem_type(type); - for (const auto& dim : dims) { + for (const auto &dim : dims) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); } } } -void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) { +void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { auto dtype = param->Type(); auto shape = param->Shape(); if (!dtype->isa() || !shape->isa()) { @@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro auto tensor = dyn_cast(dtype); auto elem_type = tensor->element(); - const auto& dims = dyn_cast(shape)->shape(); + const auto &dims = dyn_cast(shape)->shape(); tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); - for (const auto& dim : dims) { + for (const auto &dim : dims) { tensor_proto->add_dims(dim); } } -void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector& nodes, - std::unordered_map* op_merged_infos_ptr) { - std::unordered_map& op_merged_infos = *op_merged_infos_ptr; +void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr) { + std::unordered_map &op_merged_infos = *op_merged_infos_ptr; - for (auto& node : nodes) { + for (auto &node : nodes) { if (!node->isa()) { continue; } @@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto // if the key `input` does not exist, just create a new one op_merged_infos[cnode].referred_count += 1; } - for (auto& input : cnode->inputs()) { + for (auto &input : cnode->inputs()) { if (!input->isa()) { continue; } @@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto * | +-- Parameter * | `-- ValueNode */ -void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::unordered_map op_merged_infos; MatchAndMark(func_graph, nodes, &op_merged_infos); - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { if (!node->isa()) { continue; } @@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_shape = node->input(2); std::string name_shape; if (input_shape->isa()) { auto const_node_idx = AllocateNodeIndex(); (*node_map_ptr)[input_shape] = const_node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); name_shape = std::to_string(const_node_idx); node_proto->add_output(name_shape); node_proto->set_op_type("Constant"); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto 
*attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); @@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimReshape->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(name_x); node_proto->add_input(name_shape); } -void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_axis = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimReduceMean->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); if (input_axis->isa()) { - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("axes"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); auto axis_value = dyn_cast(input_axis)->value(); @@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons } } -void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_type = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimCast->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); if (input_type->isa()) { - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("to"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); auto type_value = dyn_cast(input_type)->value(); @@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod } } -void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); @@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo // format of x is NCHW, input format is NCHW, if 
length of input_slope is 1, insert Unsqueeze [1,2] if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { auto node_idx = AllocateNodeIndex(); - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type("Unsqueeze"); node_proto->add_output(std::to_string(node_idx)); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); attr_proto->set_name("axes"); attr_proto->add_ints(1); @@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type("PRelu"); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_x); node_proto->add_input(input_slope); } -void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert if (node->IsApply(prim::kPrimReshape)) { return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); @@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); } -size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map* node_map_ptr, - const PrimitivePtr& prim, const std::vector& inputs, - onnx::GraphProto* const graph_proto) { +size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *const graph_proto) { auto op_map = OpConvertRegistry::GetOpConvertMap(); auto op_iter = op_map.find(prim->name()); if (op_iter == op_map.end()) { MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; } - const OpNameInfo& op_convert_info = op_iter->second; + const OpNameInfo &op_convert_info = op_iter->second; auto node_idx = AllocateNodeIndex(); - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(std::to_string(node_idx)); node_proto->set_op_type(op_convert_info.onnx_type()); // Set inputs - for (const auto& input : inputs) { + for (const auto &input : inputs) { auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); node_proto->add_input(input_name); } // Set node attribute - for (const OpAttrInfo& attr : op_convert_info.op_attrs()) { - const std::string& attr_name = attr.attr_name(); + for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { + const std::string &attr_name = attr.attr_name(); ValuePtr attr_value = nullptr; if (!attr_name.empty()) { attr_value = prim->GetAttr(attr_name); @@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; } } - onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute(); + 
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); onnx_attr_proto->set_name(attr.onnx_attr_name()); attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); } return node_idx; } -void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto conv_node = dyn_cast(node->input(1)); auto input_x = conv_node->input(1); // conv input x auto input_w = conv_node->input(2); // conv weight(filter) @@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); } -void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto matmul_node = dyn_cast(node->input(1)); auto input_x = matmul_node->input(1); // matmul input x auto input_y = matmul_node->input(2); // matmul input y @@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); } -void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { auto batch_norm_node = dyn_cast(node->input(1)); PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); @@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); } -void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { if (node->inputs().size() != 2) { MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); - onnx::ValueInfoProto* output_proto = graph_proto->add_output(); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); output_proto->set_name(name); SetValueInfoType(arg, output_proto, false); } -std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { if (node->isa()) { auto iter = node_map_ptr->find(node); if (iter == node_map_ptr->end()) { @@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::mapadd_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(node_name); 
SetNodeAttribute(node->cast()->value(), node_proto); @@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::maptype_name(); } -void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) { +void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { auto tuple_ptr = dyn_cast(value); MS_EXCEPTION_IF_NULL(tuple_ptr); if (tuple_ptr->size() == 0) { @@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto } } -void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) { +void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { node_proto->set_op_type("Constant"); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; } -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { OnnxExporter exporter; return exporter.GetOnnxProtoString(func_graph); } diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc index 49dc3ab7910..2a3429ca522 100644 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ b/mindspore/ccsrc/operator/cc_implementations.cc @@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown }; // Whether has a T type data in AnyPtrList. template -bool HasType(const AnyPtrList& list) { - bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is(); }); +bool HasType(const AnyPtrList &list) { + bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is(); }); return ret; } -DataType InferType(const AnyPtrList& list) { +DataType InferType(const AnyPtrList &list) { if (HasType(list)) { return DataType::kDouble; } else if (HasType(list)) { @@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) { } #define SCALAR_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList& list) { \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ do { \ if (list.size() < 2) { \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ @@ -223,7 +223,7 @@ SCALAR_OP(Pow) SCALAR_OP(Floordiv) #define LOGIC_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList& list) { \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ if (list.size() < 2) { \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ } \ @@ -274,7 +274,7 @@ LOGIC_OP(Ne) LOGIC_OP(Le) LOGIC_OP(Ge) -ValuePtr ScalarUAdd(const ValuePtrList& list) { +ValuePtr ScalarUAdd(const ValuePtrList &list) { if (list.size() != 1) { MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); } @@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) { return x; } -ValuePtr ScalarUSub(const ValuePtrList& list) { +ValuePtr ScalarUSub(const ValuePtrList &list) { if (list.size() != 1) { MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); } @@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; } -ValuePtr ScalarLog(const ValuePtrList& list) { +ValuePtr ScalarLog(const ValuePtrList &list) { if 
(list.empty()) { MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; } @@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); } -ValuePtr BoolNot(const ValuePtrList& list) { +ValuePtr BoolNot(const ValuePtrList &list) { if (list.empty()) { MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; } @@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); } -ValuePtr BoolAnd(const ValuePtrList& list) { +ValuePtr BoolAnd(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; } @@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; } -ValuePtr BoolOr(const ValuePtrList& list) { +ValuePtr BoolOr(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; } @@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; } -ValuePtr BoolEq(const ValuePtrList& list) { +ValuePtr BoolEq(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; } diff --git a/mindspore/ccsrc/operator/cc_implementations.h b/mindspore/ccsrc/operator/cc_implementations.h index 69981cea7da..cef34da4f4a 100644 --- a/mindspore/ccsrc/operator/cc_implementations.h +++ b/mindspore/ccsrc/operator/cc_implementations.h @@ -29,29 +29,29 @@ namespace prim { using Any = mindspore::Any; using AnyPtrList = std::vector>; using ValuePtrList = std::vector; -using OpsFunction = std::function; -using AnfNodeOpsFunction = std::function&)>; +using OpsFunction = std::function; +using AnfNodeOpsFunction = std::function &)>; -ValuePtr ScalarAdd(const ValuePtrList& list); -ValuePtr ScalarSub(const ValuePtrList& list); -ValuePtr ScalarMul(const ValuePtrList& list); -ValuePtr ScalarDiv(const ValuePtrList& list); -ValuePtr ScalarMod(const ValuePtrList& list); -ValuePtr ScalarPow(const ValuePtrList& list); -ValuePtr ScalarFloordiv(const ValuePtrList& list); -ValuePtr ScalarUAdd(const ValuePtrList& list); -ValuePtr ScalarUSub(const ValuePtrList& list); -ValuePtr ScalarLog(const ValuePtrList& list); -ValuePtr ScalarEq(const ValuePtrList& list); -ValuePtr ScalarLt(const ValuePtrList& list); -ValuePtr ScalarGt(const ValuePtrList& list); -ValuePtr ScalarNe(const ValuePtrList& list); -ValuePtr ScalarLe(const ValuePtrList& list); -ValuePtr ScalarGe(const ValuePtrList& list); -ValuePtr BoolNot(const ValuePtrList& list); -ValuePtr BoolAnd(const ValuePtrList& list); -ValuePtr BoolOr(const ValuePtrList& list); -ValuePtr BoolEq(const ValuePtrList& list); +ValuePtr ScalarAdd(const ValuePtrList &list); +ValuePtr ScalarSub(const ValuePtrList &list); +ValuePtr ScalarMul(const ValuePtrList &list); +ValuePtr ScalarDiv(const ValuePtrList &list); +ValuePtr ScalarMod(const ValuePtrList &list); +ValuePtr ScalarPow(const ValuePtrList &list); +ValuePtr ScalarFloordiv(const ValuePtrList &list); +ValuePtr ScalarUAdd(const ValuePtrList &list); +ValuePtr ScalarUSub(const ValuePtrList &list); +ValuePtr ScalarLog(const ValuePtrList &list); +ValuePtr ScalarEq(const ValuePtrList &list); +ValuePtr ScalarLt(const ValuePtrList &list); +ValuePtr ScalarGt(const ValuePtrList &list); +ValuePtr ScalarNe(const 
ValuePtrList &list); +ValuePtr ScalarLe(const ValuePtrList &list); +ValuePtr ScalarGe(const ValuePtrList &list); +ValuePtr BoolNot(const ValuePtrList &list); +ValuePtr BoolAnd(const ValuePtrList &list); +ValuePtr BoolOr(const ValuePtrList &list); +ValuePtr BoolEq(const ValuePtrList &list); std::vector BroadcastShape_(std::vector s1, std::vector s2); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc index 9a665e8a30d..bf0dcf37d42 100644 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ b/mindspore/ccsrc/operator/composite/composite.cc @@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared("tail"); // Apply a function of two arguments cumulatively to the items of a sequence, // from left to right, so as to reduce the sequence to a single value.For example, // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). -AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { +AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { std::shared_ptr ret; size_t size = list.size(); if (size < 2) { @@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { return ret; } -AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector& list) { +AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { size_t size = list.size(); if (size < 2) { MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; @@ -121,7 +121,7 @@ void HyperMap::Init() { {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -HyperMap::HyperMap(const std::shared_ptr& fn_leaf) +HyperMap::HyperMap(const std::shared_ptr &fn_leaf) : MetaFuncGraph("hyper_map"), fn_leaf_(fn_leaf), broadcast_(false), @@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr& fn_leaf) Init(); } -HyperMap::HyperMap(const HyperMap& h) +HyperMap::HyperMap(const HyperMap &h) : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { Init(); } -AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); std::vector inputs; if (fn_arg != nullptr) { @@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf } (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), - [](const std::pair& item) { return item.first; }); + [](const std::pair &item) { return item.first; }); return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { auto lhs = std::static_pointer_cast(item.second); MS_EXCEPTION_IF_NULL(lhs); return lhs->elements().size() != size; @@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const 
std::shared_ptr& type, const FuncGraph (void)std::transform( arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair& item) { + [&func_graph, i](const std::pair &item) { return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); }); @@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraph return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { auto lhs = std::static_pointer_cast(item.second); MS_EXCEPTION_IF_NULL(lhs); return lhs->elements().size() != size; @@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGrap return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(func_graph); @@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGrap return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { bool found = false; TypeId id = kObjectTypeEnd; std::pair pair; - for (auto& item : arg_map) { + for (auto &item : arg_map) { pair = item; id = item.second->type_id(); if (nonleaf_.count(id)) { @@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a if (found) { // In a nonleaf situation, all arguments must have the same generic. 
- bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair &item) { if (item.first != pair.first) { return item.second->type_id() != pair.second->type_id(); } @@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; int idx = 0; - for (auto& item : arg_map) { + for (auto &item : arg_map) { oss << ++idx << ": " << item.second->ToString() << "\n"; } MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); @@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a } } -ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) { +ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { TypePtr type_tensor = std::make_shared(); bool flag = std::any_of( args_spec_list.begin(), args_spec_list.end(), - [type_tensor](const std::pair& item) { return IsSubType(item.second, type_tensor); }); + [type_tensor](const std::pair &item) { return IsSubType(item.second, type_tensor); }); if (flag && broadcast_) { ArgsPairList ret; - for (auto& item : args_spec_list) { + for (auto &item : args_spec_list) { if (!IsSubType(item.second, type_tensor)) { TypePtr type_tensor_ele = std::make_shared(item.second); ret.push_back( @@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL return args_spec_list; } -FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { +FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr ptrGraph = std::make_shared(); ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->debug_info()->set_name("hyper_map"); @@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { return ptrGraph; } -abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const { +abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { if (fn_leaf_ == nullptr) { MS_EXCEPTION_IF_NULL(args_spec_list[0]); // Assert that hypermap's function param does not contain free variables @@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& AbstractBasePtrList broadened; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr& arg) -> AbstractBasePtr { + [](const AbstractBasePtr &arg) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(arg); return arg->Broaden(); }); return broadened; } -REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { (void)py::class_>(*m, "HyperMap_") .def(py::init>(), py::arg("leaf")) .def(py::init<>()); })); -FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) { +FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { MS_EXCEPTION_IF_NULL(a_tuple); FuncGraphPtr ret = std::make_shared(); @@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu return ret; } -FuncGraphPtr Tail::GenerateListFuncGraph(const 
abstract::AbstractListPtr& a_list) { +FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { MS_EXCEPTION_IF_NULL(a_list); FuncGraphPtr ret = std::make_shared(); @@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list return ret; } -FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() != 1) { MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; } @@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) } REGISTER_PYBIND_DEFINE( - Tail_, ([](const py::module* m) { - (void)py::class_>(*m, "Tail_").def(py::init()); + Tail_, ([](const py::module *m) { + (void)py::class_>(*m, "Tail_").def(py::init()); })); -FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { int tuple_size = SizeToInt(args_spec_list.size()); std::ostringstream ss; @@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg return fg; } -GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param) +GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { if (get_by_list) { signatures_ = @@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_ } } -FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, - const std::vector& params_list, bool applyJ) { +FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, + const std::vector &params_list, bool applyJ) { FuncGraphPtr ret = std::make_shared(); ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); @@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, return ret; } -void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, +void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, ValueNodePtr opsTupleItem) { MS_EXCEPTION_IF_NULL(func_graph); @@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An } // Generate the graph.
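(Editorial aside, not part of the patch: GetGrad and doGetGrad above assemble a graph that calls the J-transformed function to obtain a (forward output, bprop) pair and then feeds a sensitivity value into bprop to read gradients back out; most of that wiring is elided in this hunk. The standalone C++ sketch below shows only that calling pattern under those assumptions. Grad, Bprop and ForwardWithBprop are illustrative stand-ins, and the get_all handling is a simplification of the real output selection.)

#include <functional>
#include <utility>
#include <vector>

// Stand-ins for the graph machinery: a forward function returns its output
// together with a bprop closure mapping an output sensitivity to input gradients.
using Bprop = std::function<std::vector<double>(double /*sens*/)>;
using ForwardWithBprop = std::function<std::pair<double, Bprop>(const std::vector<double> &)>;

// Sketch of the GetGrad flow: run the forward pass, then call bprop(sens).
// Without get_all, only the gradient of the first input is returned.
inline std::vector<double> Grad(const ForwardWithBprop &fn, const std::vector<double> &inputs,
                                bool get_all, double sens = 1.0) {
  auto [out, bprop] = fn(inputs);
  (void)out;  // only the gradients are of interest here
  std::vector<double> grads = bprop(sens);
  if (!get_all && !grads.empty()) {
    return {grads.front()};
  }
  return grads;
}

// Usage sketch: for f(x, y) = x * y, bprop(sens) = {sens * y, sens * x}, so
// Grad(mul, {3.0, 4.0}, /*get_all=*/true) would return {4.0, 3.0}.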
-FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() < 1) { MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " << args_spec_list.size() << "."; @@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp return dfBuilder; } -REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { (void)py::class_>( *m, "GradOperation_") - .def(py::init(), py::arg("fn")) - .def(py::init(), py::arg("fn"), py::arg("get_all"), + .def(py::init(), py::arg("fn")) + .def(py::init(), py::arg("fn"), py::arg("get_all"), py::arg("get_by_list"), py::arg("sens_param")); })); -MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) { +MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { fn_cache_.clear(); signatures_ = std::vector({// def multitype(*args:ref): {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) { +void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; auto fn = fn_cache_.find(types); if (fn != fn_cache_.end()) { @@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) fn_cache_[types] = s_fn; } -void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) { +void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; auto fn = fn_cache_.find(types); if (fn != fn_cache_.end()) { @@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& fn_cache_py_[types] = py_fn; } -void MultitypeFuncGraph::Register(const std::vector& types_name, const py::function& py_fn) { +void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { TypePtrList types; - for (auto& type_name : types_name) { + for (auto &type_name : types_name) { auto type_ptr = StringToType(type_name); if (type_ptr == nullptr) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; @@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector& types_name, co Register(types, py_fn); } -void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) { +void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { std::vector types_name; for (size_t it = 0; it < tuple.size(); ++it) { py::object name_py = tuple[it]; @@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& } Register(types_name, py_fn); } -static TypePtr UnwrapRef(const TypePtr& type) { +static TypePtr UnwrapRef(const TypePtr &type) { if (type->isa()) { return type->cast()->subtype(); } return type; } -FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { +FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { bool find_fn = false; py::function py_fn; - for (auto& item : fn_cache_py_) { + for (auto &item : fn_cache_py_) { TypePtrList sign = item.first; if 
(sign.size() != types.size()) { continue; @@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ << "`, corresponding location info:\n"; int idx = 0; - for (auto& item : fn_cache_py_) { + for (auto &item : fn_cache_py_) { FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); if (func_graph == nullptr) { MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; @@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { << oss.str(); } -REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { (void)py::class_>( *m, "MultitypeFuncGraph_") - .def(py::init()) + .def(py::init()) .def("register_fn", &MultitypeFuncGraph::PyRegister); })); // Generate the ListMap func graph. -FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { size_t args_num = args_spec_list.size(); // args: fn, list1, list2, ... if (args_num < 2) { @@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis return fg_ptr; } -void ListMap::MakeCond(const std::vector& lists, const FuncGraphPtr& fgnext_ptr, - const FuncGraphPtr& fg_ptr) { +void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, + const FuncGraphPtr &fg_ptr) { MS_EXCEPTION_IF_NULL(fg_ptr); AnfNodePtr fn = fg_ptr->add_parameter(); @@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector& lists, const FuncGraphPtr& fgtrue_ptr->set_output(output_cnode); } -void ListMap::MakeNext(const std::vector& lists, const FuncGraphPtr& fgcond_ptr, - const FuncGraphPtr& fg_ptr) { +void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, + const FuncGraphPtr &fg_ptr) { MS_EXCEPTION_IF_NULL(fg_ptr); AnfNodePtr fn = fg_ptr->add_parameter(); @@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector& lists, const FuncGraphPtr& fg_ptr->set_output(output_cnode); } -FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // args: tuple1, tuple2 abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); AbstractBasePtr abs_a = args_spec_list[0]; @@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li return ret; } -int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) { +int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { MS_EXCEPTION_IF_NULL(scalar); return GetValue(scalar->BuildValue()); } @@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) { return index; } -int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) { +int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { MS_EXCEPTION_IF_NULL(member); if (member->isa()) { @@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std << member->ToString(); } -void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index, - int* stop_index, int* step_value) { +void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, 
const AbstractSlicePtr &slice, int *start_index, + int *stop_index, int *step_value) { MS_EXCEPTION_IF_NULL(tuple); MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(start_index); @@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl } } -FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tuple // args: tuple, start index, end index, step const std::string op_name("TupleSlice"); @@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ return ret; } -int ConvertBinaryToDecimal(const std::vector& number_bin) { +int ConvertBinaryToDecimal(const std::vector &number_bin) { unsigned int number_dec = 0; for (size_t index = 0; index < number_bin.size(); index++) { number_dec |= number_bin[index] << index; @@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector& number_bin) { return static_cast(number_dec); } -void ParseSlice(const AbstractSlicePtr& slice, std::vector* begin, std::vector* end, - std::vector* strides, int length) { +void ParseSlice(const AbstractSlicePtr &slice, std::vector *begin, std::vector *end, + std::vector *strides, int length) { MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); @@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector* begin, std::vec strides->push_back(step_value); } -int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector& shape, - std::vector* begin, std::vector* end, std::vector* strides) { +int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector &shape, + std::vector *begin, std::vector *end, std::vector *strides) { MS_EXCEPTION_IF_NULL(slice_tuple); MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); @@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, return ConvertBinaryToDecimal(shrink); } -int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector& shape, - std::vector* begin, std::vector* end, std::vector* strides) { +int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector &shape, + std::vector *begin, std::vector *end, std::vector *strides) { MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(strides); @@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const return 0; } -int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector& shape, - std::vector* begin, std::vector* end, - std::vector* strides) { +int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector &shape, + std::vector *begin, std::vector *end, + std::vector *strides) { MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(strides); @@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co return 1; } -FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tensor // args: tensor, slice or slice tuple const std::string op_name = std::string("TensorSlice"); @@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const 
AbstractBasePtrList& args_spec shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); } else { std::ostringstream args_info; - for (const auto& arg : args_spec_list) { + for (const auto &arg : args_spec_list) { MS_EXCEPTION_IF_NULL(arg); args_info << arg->ToString() << "\n"; } @@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec return ret_graph; } -REGISTER_PYBIND_DEFINE( - TupleAdd_, ([](const py::module* m) { - (void)py::class_>(*m, "TupleAdd_").def(py::init()); - })); - -REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) { - (void)py::class_>(*m, "TupleSlice_") - .def(py::init()); +REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleAdd_") + .def(py::init()); })); -REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleSlice_") + .def(py::init()); + })); + +REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { (void)py::class_>(*m, "TensorSlice_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h index dc8627ba615..1dad2e08cf4 100644 --- a/mindspore/ccsrc/operator/composite/composite.h +++ b/mindspore/ccsrc/operator/composite/composite.h @@ -47,20 +47,20 @@ using ArgsPairList = std::vector>; class MultitypeFuncGraph : public MetaFuncGraph { public: - explicit MultitypeFuncGraph(const std::string& name); + explicit MultitypeFuncGraph(const std::string &name); ~MultitypeFuncGraph() override = default; MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) - using specialize_fn = FuncGraph* (*)(TypePtrList); + using specialize_fn = FuncGraph *(*)(TypePtrList); // Register a method which specialize based on types vectors; - virtual void Register(const TypePtrList& types, specialize_fn s_fn); - virtual void Register(const TypePtrList& types, const py::function& py_fn); - virtual void Register(const std::vector& types_name, const py::function& py_fn); - virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn); + virtual void Register(const TypePtrList &types, specialize_fn s_fn); + virtual void Register(const TypePtrList &types, const py::function &py_fn); + virtual void Register(const std::vector &types_name, const py::function &py_fn); + virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); - FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } - const std::unordered_map& GetPyFunctions() const { + const std::unordered_map &GetPyFunctions() const { return fn_cache_py_; } @@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr; class HyperMap : public MetaFuncGraph { public: - explicit HyperMap(const std::shared_ptr& fn_leaf = nullptr); - HyperMap(const HyperMap& h); + explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); + HyperMap(const HyperMap &h); void Init(); - HyperMap& operator=(const HyperMap& h) { + HyperMap &operator=(const HyperMap &h) { if (this != &h) { fn_leaf_ = h.fn_leaf_; broadcast_ = h.broadcast_; @@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph { ~HyperMap() override = default; MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) - abstract::AbstractBasePtrList 
NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override; + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } private: - AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map); - ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list); + AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); + ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); MultitypeFuncGraphPtr fn_leaf_; bool broadcast_; @@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr; class HyperMapPy : public HyperMap { public: - explicit HyperMapPy(const std::shared_ptr& fn_leaf = nullptr) : HyperMap(fn_leaf) {} + explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} ~HyperMapPy() override = default; MS_DECLARE_PARENT(HyperMapPy, HyperMap) }; @@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap; class Tail : public MetaFuncGraph { public: - explicit Tail(const std::string& name) : MetaFuncGraph(name) {} + explicit Tail(const std::string &name) : MetaFuncGraph(name) {} ~Tail() override = default; MS_DECLARE_PARENT(Tail, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple); - FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); + FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); - friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } }; using TailPtr = std::shared_ptr; class MakeTupleGradient : public MetaFuncGraph { public: - explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {} + explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} ~MakeTupleGradient() 
override = default; MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } }; using MakeTupleGradientPtr = std::shared_ptr; class GradOperation : public MetaFuncGraph { public: - explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false, + explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, bool sens_param = false); ~GradOperation() override = default; MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) - FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector& ptrParams, + FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams, bool applyJ = false); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; bool sens_param() const { return sens_param_; } bool get_all_; bool get_by_list_; bool sens_param_; private: - void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, + void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, ValueNodePtr opsTupleItem); }; using GradOperationPtr = std::shared_ptr; class ListMap { public: - explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); } + explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } ~ListMap() = default; - void MakeCond(const std::vector& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr); - void MakeNext(const std::vector& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list); + void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); + void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); private: std::string name_; @@ -181,31 +181,31 @@ class ListMap { class TupleAdd : public MetaFuncGraph { public: - explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {} + explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} ~TupleAdd() override = default; MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } }; using TupleAddPtr = std::shared_ptr; class TupleSlice : public MetaFuncGraph { public: - explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {} + explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} ~TupleSlice() override = default; MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& 
args_spec_list) override; - friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } }; using TupleSlicePtr = std::shared_ptr; class TensorSlice : public MetaFuncGraph { public: - explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {} + explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} ~TensorSlice() override = default; MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } }; using TensorSlicePtr = std::shared_ptr; diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index a4a26377f5e..95e38247d9a 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -34,7 +34,7 @@ namespace prim { namespace { using PatternListType = std::initializer_list; -const std::vector& GetSignature(const ValuePtr& function) { +const std::vector &GetSignature(const ValuePtr &function) { static const auto empty = std::vector(); if (function->isa()) { return function->cast()->signatures(); @@ -44,8 +44,8 @@ const std::vector& GetSignature(const ValuePtr& function) { return empty; } -void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list, - const std::vector& signature, bool has_var, std::vector* op_inputs) { +void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, + const std::vector &signature, bool has_var, std::vector *op_inputs) { std::size_t sig_size = signature.size(); auto positional_size = sig_size; if (has_var) { @@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg } // Get the largest type of index in the same SignatureEnumDType of arguments. -std::map GetMaxDtypeIndex(const std::vector& dtypes, - const abstract::AbstractBasePtrList& args_spec_list) { +std::map GetMaxDtypeIndex(const std::vector &dtypes, + const abstract::AbstractBasePtrList &args_spec_list) { // record index for signature.dtypes of the same type // eg. 
[T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} std::map> type_indexs; @@ -89,7 +89,7 @@ std::map GetMaxDtypeIndex(const std::vectorisa()) { arg_value = arg_value->cast()->ref(); @@ -104,7 +104,7 @@ std::map GetMaxDtypeIndex(const std::vector& signature, const abstract::AbstractBasePtrList& args_spec_list, - const FuncGraphPtr& graph, std::vector* op_inputs) { +void DoAutoCast(const std::vector &signature, const abstract::AbstractBasePtrList &args_spec_list, + const FuncGraphPtr &graph, std::vector *op_inputs) { std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature& sig) { return sig.dtype; }); + [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { return; @@ -143,10 +143,10 @@ void DoAutoCast(const std::vector& signature, const abstract::Abstrac } } -AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const std::vector& params_list) { +AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const std::vector &params_list) { // args: original inputs - auto& signature = GetSignature(function); + auto &signature = GetSignature(function); std::size_t sig_size = signature.size(); auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); if (sig_size > 0) { @@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func } } // namespace -AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) { +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); return new_cnode; } -FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { FuncGraphPtr func_graph = std::make_shared(); for (size_t i = 0; i < args_spec_list.size(); ++i) { diff --git a/mindspore/ccsrc/operator/composite/do_signature.h b/mindspore/ccsrc/operator/composite/do_signature.h index b88053e2247..3e1596d63f4 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.h +++ b/mindspore/ccsrc/operator/composite/do_signature.h @@ -37,17 +37,17 @@ namespace mindspore { namespace prim { class DoSignatureMetaFuncGraph : public MetaFuncGraph { public: - explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function) + explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) : MetaFuncGraph("S-" + name), function_(function) {} ~DoSignatureMetaFuncGraph() override = default; MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override; + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list)
override; const ValuePtr function() const { return function_; } - friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) { + friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { return &lhs == &rhs; } @@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph { }; using RWSignaturePtr = std::shared_ptr; -AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs); +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.cc b/mindspore/ccsrc/operator/composite/list_append_operation.cc index 8621a8a8ba0..b5a4fc626e9 100644 --- a/mindspore/ccsrc/operator/composite/list_append_operation.cc +++ b/mindspore/ccsrc/operator/composite/list_append_operation.cc @@ -27,7 +27,7 @@ namespace mindspore { // namespace to support composite operators definition namespace prim { -FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) { +FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { abstract::CheckArgsSize("ListAppend", args_list, 2); AbstractBasePtr arg0 = args_list[0]; @@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& return ret; } -REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { (void)py::class_>(*m, "ListAppend_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.h b/mindspore/ccsrc/operator/composite/list_append_operation.h index f34b6b864ec..1da3f9a0098 100644 --- a/mindspore/ccsrc/operator/composite/list_append_operation.h +++ b/mindspore/ccsrc/operator/composite/list_append_operation.h @@ -28,15 +28,15 @@ namespace mindspore { namespace prim { class ListAppend : public MetaFuncGraph { public: - explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {} + explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {} ~ListAppend() override = default; MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override; - friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) { + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override; + friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) { os << list_append.name_; return os; } - friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; } }; using ListAppendPtr = std::shared_ptr; } // namespace prim diff --git a/mindspore/ccsrc/operator/composite/unpack_call.cc b/mindspore/ccsrc/operator/composite/unpack_call.cc index 64d6b3433b1..122f276657c 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.cc +++ b/mindspore/ccsrc/operator/composite/unpack_call.cc @@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg; using 
mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuplePtr; -FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tensor // args: tensor, slice or slice tuple const std::string op_name = std::string("UnpackCall"); @@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ AnfNodePtr para_dict = ret_graph->add_parameter(); auto dict_elems = arg_dict->elements(); (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), - [ret_graph, para_dict](const AbstractAttribute& item) { + [ret_graph, para_dict](const AbstractAttribute &item) { auto dict_get_item = ret_graph->NewCNode( {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); return ret_graph->NewCNode( @@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ return ret_graph; } -REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { (void)py::class_>(*m, "UnpackCall_") - .def(py::init()); + .def(py::init()); })); } // namespace prim diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h index 7ec5f9ad33d..2f39615c1a6 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ b/mindspore/ccsrc/operator/composite/unpack_call.h @@ -40,11 +40,11 @@ namespace prim { // and generate positional parameters and key-value pairs for function. class UnpackCall : public MetaFuncGraph { public: - explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} + explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} ~UnpackCall() override = default; MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } }; using UnpackCallPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/operator/composite/zip_operation.cc b/mindspore/ccsrc/operator/composite/zip_operation.cc index b87e19b0097..4d34163f28e 100644 --- a/mindspore/ccsrc/operator/composite/zip_operation.cc +++ b/mindspore/ccsrc/operator/composite/zip_operation.cc @@ -36,7 +36,7 @@ namespace prim { using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractTuple; -FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // zip operation: // input: tuple arguments // output: tuple of items of input iterated on every input @@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; } - auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool { + auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { MS_EXCEPTION_IF_NULL(abs); return abs->isa(); }); @@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& 
args_spe } auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), - [](const AbstractBasePtr& x, const AbstractBasePtr& y) { + [](const AbstractBasePtr &x, const AbstractBasePtr &y) { return (x->cast()->size() < y->cast()->size()); }); FuncGraphPtr ret_graph = std::make_shared(); @@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe return ret_graph; } -REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { (void)py::class_>(*m, "ZipOperation_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/zip_operation.h b/mindspore/ccsrc/operator/composite/zip_operation.h index e1fb8d60cf0..1a3fa1f5fe9 100644 --- a/mindspore/ccsrc/operator/composite/zip_operation.h +++ b/mindspore/ccsrc/operator/composite/zip_operation.h @@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr; class ZipOperation : public MetaFuncGraph { public: - explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {} + explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} ~ZipOperation() override = default; MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) { + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { os << op.name_; return os; } - friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } }; using ZipOperationPtr = std::shared_ptr; } // namespace prim diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index ffd331c6c3a..9d5777641bd 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { +ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) { py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); ValuePtr node = nullptr; bool succ = parse::ConvertData(obj, &node); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index a6c614b4947..4852e2345e5 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -26,8 +26,8 @@ namespace mindspore { // namespace to support primitive operators namespace prim { -ValuePtr GetPythonOps(const std::string& op_name, - const std::string& module_name = "mindspore._extends.parse.standard_method"); +ValuePtr GetPythonOps(const std::string &op_name, + const std::string &module_name = "mindspore._extends.parse.standard_method"); // Arithmetic extern const PrimitivePtr kPrimScalarAdd; @@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset; class DoSignaturePrimitive : public Primitive { public: - explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function) + explicit DoSignaturePrimitive(const 
std::string &name, const ValuePtr &function) : Primitive("S-Prim-" + name), function_(function) {} ~DoSignaturePrimitive() override = default; @@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr; class UnpackGraphPrimitive : public Primitive { public: - explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) + explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} ~UnpackGraphPrimitive() override = default; MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) diff --git a/mindspore/ccsrc/operator/prim_to_function.cc b/mindspore/ccsrc/operator/prim_to_function.cc index bdfe48157c2..733cdbdb73c 100644 --- a/mindspore/ccsrc/operator/prim_to_function.cc +++ b/mindspore/ccsrc/operator/prim_to_function.cc @@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction() {"scalar_sub", kPrimTypeTwoArgs}, {"scalar_floordiv", kPrimTypeTwoArgs}}) {} -bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { +bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { bool result = false; if (func != nullptr) { @@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu return result; } -int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const { +int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { MS_EXCEPTION_IF_NULL(prim); int prim_type = static_cast(kPrimTypeUnknown); diff --git a/mindspore/ccsrc/operator/prim_to_function.h b/mindspore/ccsrc/operator/prim_to_function.h index 71518e4057b..285ab8d3abb 100644 --- a/mindspore/ccsrc/operator/prim_to_function.h +++ b/mindspore/ccsrc/operator/prim_to_function.h @@ -41,21 +41,21 @@ class PrimToFunction; class PrimToFunction { public: // Return a thread-safe singleton instance - static PrimToFunction& GetInstance() { + static PrimToFunction &GetInstance() { static PrimToFunction instance; return instance; } - PrimToFunction(const PrimToFunction&) = delete; - PrimToFunction& operator=(const PrimToFunction&) = delete; + PrimToFunction(const PrimToFunction &) = delete; + PrimToFunction &operator=(const PrimToFunction &) = delete; ~PrimToFunction() = default; // Get the args and return value for a primitive instance. 
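(Editorial aside, not part of the patch: PrimToFunction is a Meyers singleton that maps a primitive's name to the arity of its scalar function type, which GetFunction then uses to build the concrete Function. The standalone C++ sketch below shows the same lookup pattern with standard containers only; PrimArity is an illustrative stand-in and its map is abbreviated to names visible in this diff.)

#include <string>
#include <unordered_map>

class PrimArity {
 public:
  // Thread-safe lazy singleton: C++11 guarantees the local static is
  // initialized exactly once.
  static PrimArity &GetInstance() {
    static PrimArity instance;
    return instance;
  }
  PrimArity(const PrimArity &) = delete;
  PrimArity &operator=(const PrimArity &) = delete;

  // Returns the number of scalar arguments for a known primitive, or -1.
  int GetPrimType(const std::string &prim_name) const {
    auto it = prim_func_type_map_.find(prim_name);
    return it == prim_func_type_map_.end() ? -1 : it->second;
  }

 private:
  PrimArity() : prim_func_type_map_{{"scalar_sub", 2}, {"scalar_floordiv", 2}} {}
  const std::unordered_map<std::string, int> prim_func_type_map_;
};

// Usage sketch: PrimArity::GetInstance().GetPrimType("scalar_sub") returns 2.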
- bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const; + bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const; private: PrimToFunction(); // Get the number of primitive arguments - int GetPrimType(const PrimitivePtr& prim) const; + int GetPrimType(const PrimitivePtr &prim) const; const std::unordered_map prim_func_type_map_; }; } // namespace prim diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.cc b/mindspore/ccsrc/optimizer/ad/adjoint.cc index 46746b3f44e..ed89aba20e6 100644 --- a/mindspore/ccsrc/optimizer/ad/adjoint.cc +++ b/mindspore/ccsrc/optimizer/ad/adjoint.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace ad { -Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller) +Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) : primal_(primal), caller_(caller), dout_(nullptr) { if (k != nullptr) { k_ = k; @@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP AnfNodePtr Adjoint::k() { return k_; } -void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } +void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } -void Adjoint::UpdateK(const AnfNodePtr& new_k) { +void Adjoint::UpdateK(const AnfNodePtr &new_k) { MS_EXCEPTION_IF_NULL(new_k); MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); // In recursive case, it needs update. - for (auto& user : k_user_) { + for (auto &user : k_user_) { MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" << new_k->ToString(); if (user.first->input(user.second) != k_) { @@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; } AnfNodePtr Adjoint::dout() { return dout_hole_; } -void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) { +void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { dout_user_.emplace_back(std::make_pair(user, index)); } -void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { +void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { if (dout_ != nullptr) { MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); auto add = prim::GetPythonOps("hyper_add"); @@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { void Adjoint::CallDoutHole() { if (dout_ != nullptr) { - for (auto& user : dout_user_) { + for (auto &user : dout_user_) { MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " << dout_->ToString(); if (user.first->input(user.second) != dout_hole_) { diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.h b/mindspore/ccsrc/optimizer/ad/adjoint.h index 673928129b7..b2dae8e66f1 100644 --- a/mindspore/ccsrc/optimizer/ad/adjoint.h +++ b/mindspore/ccsrc/optimizer/ad/adjoint.h @@ -28,15 +28,15 @@ namespace mindspore { namespace ad { class Adjoint { public: - Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller); + Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); ~Adjoint() = default; AnfNodePtr primal(); AnfNodePtr k(); - void UpdateK(const AnfNodePtr& k); - void RegisterKUser(const CNodePtr& user, size_t index); + void UpdateK(const AnfNodePtr &k); + void RegisterKUser(const CNodePtr &user, size_t index); AnfNodePtr dout(); - void 
AccumulateDout(const AnfNodePtr& dout_factor); - void RegisterDoutUser(const CNodePtr& user, size_t index); + void AccumulateDout(const AnfNodePtr &dout_factor); + void RegisterDoutUser(const CNodePtr &user, size_t index); void CallDoutHole(); private: diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc index 9e713d34255..fe11191546c 100644 --- a/mindspore/ccsrc/optimizer/clean.cc +++ b/mindspore/ccsrc/optimizer/clean.cc @@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractTuple; -static AbstractBasePtr Reabs(const AbstractBasePtr& t) { +static AbstractBasePtr Reabs(const AbstractBasePtr &t) { if (t == nullptr) { return nullptr; } @@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { AbstractBasePtrList baselist; auto attributes = abs_class->attributes(); (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), - [](const AbstractAttribute& item) { return item.second; }); + [](const AbstractAttribute &item) { return item.second; }); res = std::make_shared(baselist); } else if (t->isa()) { auto abs_dict = dyn_cast(t); AbstractBasePtrList baselist; auto elements = abs_dict->elements(); (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), - [](const AbstractAttribute& item) { return item.second; }); + [](const AbstractAttribute &item) { return item.second; }); res = std::make_shared(baselist); } else if (t->isa()) { auto abs_dict = dyn_cast(t); @@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { return res; } -AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [getattr, data, attribute] MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); @@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; auto ct = dyn_cast(dt); - const auto& cmap = ct->attributes(); + const auto &cmap = ct->attributes(); int count = 0; - for (auto& item : cmap) { + for (auto &item : cmap) { if (cons_is_str && item.first == cons_str) { break; } @@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } -AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); // Inputs should be [dict_getitem, dict, item] - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); AnfNodePtr data = inputs[1]; @@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { auto cons_str = cons_is_str ? 
GetValue(GetValueNode(cons)) : ""; auto ct = dyn_cast(dt); - const auto& cmap = ct->elements(); + const auto &cmap = ct->elements(); int count = 0; - for (auto& item : cmap) { + for (auto &item : cmap) { if (cons_is_str && item.first == cons_str) { break; } @@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } -AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { +AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { return node->func_graph()->NewCNode(inputs); } -AnfNodePtr ErasePartialNode(const CNodePtr& node) { +AnfNodePtr ErasePartialNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); @@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) { return nullptr; } -AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { +AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { return node->func_graph()->NewCNode(inputs); } -AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [list_getitem, list, item] if (inputs.size() < 3) { MS_LOG(EXCEPTION) << "Node's input number < 3."; @@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); } -AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { +AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [list_setitem, list, index, item] if (inputs.size() < 4) { MS_LOG(EXCEPTION) << "Node's input number < 4."; @@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); } -AnfNodePtr EraseMakeDictNode(const CNodePtr& node) { +AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); return inputs[2]; } -AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) { +AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [make_keyword_arg, key, value] MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); return inputs[2]; } -AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) { +AnfNodePtr 
EraseExtractKeywordArg(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [extract_keyword_arg, arg, key] MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); return inputs[2]; } -ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) { +ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { const int DEPTH_MAX = 5; if (depth > DEPTH_MAX) { MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; } std::vector elements; - for (const auto& it : value_list->value()) { + for (const auto &it : value_list->value()) { ValuePtr value = nullptr; if (it->isa()) { value = ConvertValueListToValueTuple(it->cast(), depth + 1); @@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d return std::make_shared(elements); } -AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { +AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(node); ValuePtr value = node->value(); auto value_list = value->cast(); @@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { // Convert class to Tuple // Convert getattr to getitem // Convert make_record to make_tuple -void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { +void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); - for (auto& node : all_node) { + for (auto &node : all_node) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); AnfNodePtr new_node = nullptr; @@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& } } - for (auto& node : manager->all_nodes()) { + for (auto &node : manager->all_nodes()) { auto ret = Reabs(node->abstract()); node->set_abstract(ret); } } // expand tuples in graph parameters -static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph, - const std::vector& params) { +static std::vector ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, + const std::vector &params) { MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(func_graph); std::vector new_params; - for (const auto& param : params) { + for (const auto &param : params) { MS_EXCEPTION_IF_NULL(param); auto param_abs = param->abstract(); MS_EXCEPTION_IF_NULL(param_abs); @@ -350,7 +350,7 @@ static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, con std::vector new_param; std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; auto abs_tuple = dyn_cast(param_abs); - for (auto& elem : abs_tuple->elements()) { + for (auto &elem : abs_tuple->elements()) { auto np = std::make_shared(func_graph); np->set_abstract(elem); new_param.emplace_back(np); @@ -366,11 +366,11 @@ static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, con } // expand tuples in graph applies -static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const std::vector& inputs) { +static std::vector ExpandTuplesC(const FuncGraphPtr &graph, const std::vector &inputs) { MS_EXCEPTION_IF_NULL(graph); std::vector new_inputs; - for (const auto& input : inputs) { + for (const auto &input : 
inputs) { MS_EXCEPTION_IF_NULL(input); auto input_abs = input->abstract(); @@ -391,7 +391,7 @@ static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const st int idx = 0; std::vector new_input; auto abs_tuple = dyn_cast(input_abs); - for (auto& elem : abs_tuple->elements()) { + for (auto &elem : abs_tuple->elements()) { auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); AbstractBasePtr aptr = std::make_shared(std::make_shared(idx)); c_node->input(2)->set_abstract(aptr); @@ -416,19 +416,19 @@ static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const st // tuples in Graph's parameters: AbstractTuple (a, b, c) --> // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) // cppcheck-suppress unusedFunction -void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { +void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); - for (auto& node : all_node) { + for (auto &node : all_node) { auto cnode = node->cast(); if (cnode == nullptr) { continue; } - const auto& inputs = cnode->inputs(); + const auto &inputs = cnode->inputs(); // Bypass the first input in inputs as it's fn. if (!IsValueNode(inputs[0])) { @@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { } FuncGraphSet all_graph = manager->func_graphs(); - for (auto& func_graph : all_graph) { + for (auto &func_graph : all_graph) { MS_EXCEPTION_IF_NULL(func_graph); auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); manager->SetParameters(func_graph, expand_p); diff --git a/mindspore/ccsrc/optimizer/control_depend.h b/mindspore/ccsrc/optimizer/control_depend.h index 2a51a247180..076e2c02294 100644 --- a/mindspore/ccsrc/optimizer/control_depend.h +++ b/mindspore/ccsrc/optimizer/control_depend.h @@ -22,7 +22,7 @@ namespace mindspore { namespace opt { // Automatically adding control depend based on effect order and side effect analysis. 
-void AddControlDepend(const FuncGraphPtr& graph); +void AddControlDepend(const FuncGraphPtr &graph); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc index 5daeced3a5e..32a42bc16b0 100644 --- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc +++ b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc @@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, Func nodes.push_back(func_node); // {unpackcall, {GradOperation, ...}, args...} std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr& node) { return node; }); + [](const AnfNodePtr &node) { return node; }); unpack_graph_node = func_graph->NewCNode(nodes); } else { auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); @@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, Func nodes.push_back(func_node); // {{GradOperation, ...}, args...} std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr& node) { return node; }); + [](const AnfNodePtr &node) { return node; }); unpack_graph_node = func_graph->NewCNode(nodes); } return unpack_graph_node; } // get metagraph of value node -MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { +MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { ValuePtr value; if (IsValueNode(node)) { value = GetValueNode(node)->cast()->function(); @@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { } // check if node is a specific metafuncgraph op -bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { +bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { if (node != nullptr) { auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); if (meta_func_graph_ptr == nullptr) { @@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr // {{GradOperation, g, w}, Ys} // {UnPackCall, {GradOperation, g, w}, Ys} -AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { +AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 24339ddb845..0dbaf1107fe 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -31,20 +31,20 @@ namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, - const RenormAction& renorm_action) { - auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &renorm_action) { + auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const std::vector& prims, const RenormAction& renorm_action) { - auto fn = [prims](const AnfNodePtr& node) -> bool { 
+SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, + const std::vector &prims, const RenormAction &renorm_action) { + auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { return false; } - for (auto& prim : prims) { + for (auto &prim : prims) { if (IsPrimitiveCNode(node, prim)) { return true; } @@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const PredicateFuncType& predicate, const RenormAction& renorm_action) { +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, + const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } -AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { #ifdef ENABLE_PROFILE double t = GetTime(); #endif @@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode return result; } -bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node, - const SubstitutionPtr& transform) const { +bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, + const SubstitutionPtr &transform) const { FuncGraphManagerPtr manager = optimizer->manager(); std::unordered_set seen_node; std::deque todo{root_node}; @@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo } if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } - auto& node_users = manager->node_users(); + auto &node_users = manager->node_users(); if (change && node_users.find(node) != node_users.end()) { - for (auto& use : node_users[node]) { + for (auto &use : node_users[node]) { auto use_node = use.first; todo.push_back(use_node); if (seen_node.find(use_node) != seen_node.end()) { @@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo return changes; } -bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const { +bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = optimizer->manager(); @@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize do { loop = false; - for (auto const& transform : list_) { + for (auto const &transform : list_) { auto change = ApplyTransform(optimizer, func_graph->output(), transform); changes = changes || change; loop = loop || change; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc index 03f7d054e05..b4f4cb5b228 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) { 
+std::unordered_set FindCNodesWithPara(const AnfNodePtr &para, uint32_t recursive_times = 0) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -39,7 +39,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t MS_EXCEPTION_IF_NULL(manager); auto node_set = manager->node_users()[para]; std::unordered_set cnode_set; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { auto cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -54,7 +54,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t (void)cnode_set.emplace(cnode); } else { auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); - for (auto& cnode_sub : cnode_set_sub) { + for (auto &cnode_sub : cnode_set_sub) { (void)cnode_set.emplace(cnode_sub); } } @@ -63,8 +63,8 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t } Status AllreduceFusion::AddNodeToGraph() { - const auto& parameters = root_graph_->parameters(); - for (auto& parameter : parameters) { + const auto &parameters = root_graph_->parameters(); + for (auto &parameter : parameters) { if (!ParameterRequireGrad(parameter)) { continue; } @@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() { if (cnode_set.empty()) { continue; } - for (auto& cnode : cnode_set) { + for (auto &cnode : cnode_set) { MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); @@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() { return SUCCESS; } -CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const { +CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi return cnode_dist; } else { auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); - for (auto& ele : cnode_dist_next) { + for (auto &ele : cnode_dist_next) { cnode_dist[ele.first] = cost + ele.second; } } } else { auto cnode_dist_next = FindNextCNodes(cnode); - for (auto& ele : cnode_dist_next) { + for (auto &ele : cnode_dist_next) { cnode_dist[ele.first] = ele.second; } } return cnode_dist; } -CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const { +CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! 
Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; } - const auto& from_inputs = from->inputs(); + const auto &from_inputs = from->inputs(); std::unordered_map dist_map; MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; - for (auto& input_node : from_inputs) { + for (auto &input_node : from_inputs) { auto cnode_dist = FindCNode(input_node, recursive_times + 1); - for (auto& ele : cnode_dist) { + for (auto &ele : cnode_dist) { (void)dist_map.emplace(ele); } } @@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu Status AllreduceFusion::AddEdgeToGraph() { std::unordered_map cnode_state_map; - const auto& cnodes = allreduce_graph_.cnode_set(); - for (auto& cnode : cnodes) { + const auto &cnodes = allreduce_graph_.cnode_set(); + for (auto &cnode : cnodes) { cnode_state_map[cnode] = 0; } - const auto& head_cnode = allreduce_graph_.head_cnode(); + const auto &head_cnode = allreduce_graph_.head_cnode(); std::queue cnode_queue; cnode_queue.emplace(head_cnode); cnode_state_map[head_cnode] = 1; @@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() { cnode_queue.pop(); cnode_state_map[cur_cnode] = 2; auto next = FindNextCNodes(cur_cnode); - for (auto& ele : next) { - auto& cnode = ele.first; - auto& dist = ele.second; + for (auto &ele : next) { + auto &cnode = ele.first; + auto &dist = ele.second; if (cnode_state_map[cnode] == 0) { cnode_queue.emplace(cnode); cnode_state_map[cnode] = 1; @@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() { return SUCCESS; } -std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) { +std::vector FindMirror(const AnfNodePtr &para, uint32_t recursive_times = 0) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! 
Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -184,7 +184,7 @@ std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_time MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[para]; std::vector cnode_list; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { auto cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -210,7 +210,7 @@ std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_time return cnode_list; } -void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) { +void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string &parameter_name) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; auto node_prim = GetValueNode(mirror_cnode->input(0)); @@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); } -Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { +Status FindMirrorAndSetFusion(const AnfNodePtr &para, int32_t fusion) { auto mirror_cnodes = FindMirror(para); if (mirror_cnodes.empty()) { MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; return SUCCESS; } if (mirror_cnodes.size() > 2) { - for (auto& mirror_cnode : mirror_cnodes) { + for (auto &mirror_cnode : mirror_cnodes) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(INFO) << mirror_cnode->DebugString(); } @@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { << "Mirror CNode found."; return FAILED; } - for (auto& mirror_cnode : mirror_cnodes) { + for (auto &mirror_cnode : mirror_cnodes) { auto parameter_name = ParameterName(para); SetMirrorFusion(mirror_cnode, fusion, parameter_name); } return SUCCESS; } -Status FindMirrorAndSetFusion(const std::vector& paras, int32_t fusion) { - for (auto& param_node : paras) { +Status FindMirrorAndSetFusion(const std::vector &paras, int32_t fusion) { + for (auto &param_node : paras) { if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; return FAILED; @@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector& paras, int32_t fusi return SUCCESS; } -Status AllreduceFusion::SetFusion(const std::vector& cost_map) { +Status AllreduceFusion::SetFusion(const std::vector &cost_map) { if (cost_map.size() < 2) { MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); return FAILED; @@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { return SetFusionByBackwardCompAndAllreduceTime(); } -Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { +Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { if (ret == nullptr) { MS_LOG(ERROR) << "ret is nullptr."; return FAILED; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h index 67dc55836a4..43a99350954 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h @@ -50,15 +50,15 @@ class AllreduceFusion { allreduce_bandwidth_(0), computation_time_parameter_(0) {} virtual ~AllreduceFusion() = default; - Status ProcessAllreduceFusion(const CNodePtr& ret); + Status 
ProcessAllreduceFusion(const CNodePtr &ret); private: Status AddNodeToGraph(); - CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const; - CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const; + CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; + CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; Status AddEdgeToGraph(); std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; - Status SetFusion(const std::vector& cost_map); + Status SetFusion(const std::vector &cost_map); Status SetFusionByAlgorithm(int32_t algorithm); Status SetFusionByBackwardCompTime(); Status SetFusionByBackwardCompAndAllreduceTime(); diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc index 9e04593c831..2a98a38add3 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace parallel { -Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { +Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) { AllreduceNodePtr arnode; auto cnode_emplace_return = cnode_set_.emplace(node); if (!cnode_emplace_return.second) { @@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { return SUCCESS; } -Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) { +Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { auto from_arnode_iter = cnode_arnode_map_.find(from); if (from_arnode_iter == cnode_arnode_map_.end()) { MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; @@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double return SUCCESS; } -bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const { +bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { auto cnode_iter = cnode_set_.find(node); return !(cnode_iter == cnode_set_.end()); } std::vector AllreduceGraph::GetParaByCost(double from, double to) { std::vector nodes; - for (auto& cnode_arnode : cnode_arnode_map_) { + for (auto &cnode_arnode : cnode_arnode_map_) { MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() << " curr_para_size: " << cnode_arnode.second->curr_para_size(); @@ -117,7 +117,7 @@ std::pair, double> AllreduceGraph::GetParaByParaSize(dou std::vector nodes; double cur_para_size = 0; double from = to; - for (auto& arnode : arnode_vec_) { + for (auto &arnode : arnode_vec_) { if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { continue; } @@ -135,14 +135,14 @@ std::pair, double> AllreduceGraph::GetParaByParaSize(dou void AllreduceGraph::PrintCNodeSet() const { MS_LOG(INFO) << "CNodeSet:"; - for (auto& cnode : cnode_set_) { + for (auto &cnode : cnode_set_) { MS_LOG(INFO) << cnode->DebugString(); } } void AllreduceGraph::PrintAllredueGraphInfo() const { MS_LOG(INFO) << "max: " << max_; - for (auto& cnode_arnode : cnode_arnode_map_) { + for (auto &cnode_arnode : cnode_arnode_map_) { MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); MS_LOG(INFO) << "arnode info: "; cnode_arnode.second->ToString(); @@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const 
{ void AllreduceGraph::PrintArnodeVec() const { MS_LOG(INFO) << "ArnodeVec:"; - for (auto& arnode : arnode_vec_) { + for (auto &arnode : arnode_vec_) { arnode.ToString(); } } void AllreduceGraph::PrintArnodeSet() const { MS_LOG(INFO) << "ArnodeSet:"; - for (auto& arnode : arnode_set_) { + for (auto &arnode : arnode_set_) { arnode->ToString(); } } void AllreduceGraph::SortArnode() { arnode_vec_.clear(); - for (auto& node : arnode_set_) { + for (auto &node : arnode_set_) { arnode_vec_.emplace_back(*node); } std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); @@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() { Status AllreduceGraph::RemoveExtraParas() { std::unordered_set para_map; - for (auto& node : arnode_vec_) { - for (auto& para : node.paras()) { + for (auto &node : arnode_vec_) { + for (auto &para : node.paras()) { auto emplac_result = para_map.emplace(para); if (!emplac_result.second) { MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; @@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() { return SUCCESS; } -Status AllreduceGraph::set_head_cnode(const CNodePtr& node) { +Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { auto arnode = std::make_shared(AllreduceNode()); if (arnode->Init(node) != SUCCESS) { MS_LOG(ERROR) << "AllreduceNode Init failed"; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h index f0db78a1308..b2084b735cb 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h @@ -42,9 +42,9 @@ class AllreduceGraph { cnode_arnode_map_(), max_(0) {} virtual ~AllreduceGraph() = default; - Status AddNode(const CNodePtr& node, const AnfNodePtr& para); - Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist); - bool NodeInGraph(const CNodePtr& node) const; + Status AddNode(const CNodePtr &node, const AnfNodePtr &para); + Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); + bool NodeInGraph(const CNodePtr &node) const; std::vector GetParaByCost(double from, double to); // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is // over para_size. 
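For illustration only, not a hunk of this patch: the comment above describes the selection that AllreduceGraph::GetParaByParaSize performs over the vector that SortArnode keeps ordered by descending depend_feat_size. A minimal sketch of that walk, written with the hypothetical stand-ins ArNodeLite and GetParaByParaSizeLite rather than the real AllreduceNode types, and omitting the special handling of max_ seen in the real code, could look like:

#include <utility>
#include <vector>

// Hypothetical stand-in for AllreduceNode: just the two sizes the selection looks at.
struct ArNodeLite {
  double depend_feat_size;
  double para_size;
};

// Walk nodes already sorted by descending depend_feat_size, skip the ones still at or
// above the current bound `to`, and collect nodes until the accumulated parameter
// size reaches `para_size`. Returns the picked nodes and the new bound.
std::pair<std::vector<ArNodeLite>, double> GetParaByParaSizeLite(const std::vector<ArNodeLite> &sorted_nodes,
                                                                 double to, double para_size) {
  std::vector<ArNodeLite> picked;
  double cur_para_size = 0;
  double from = to;
  for (const auto &node : sorted_nodes) {
    if (node.depend_feat_size >= to) {
      continue;  // not yet below the current boundary
    }
    picked.push_back(node);
    cur_para_size += node.para_size;
    from = node.depend_feat_size;
    if (cur_para_size >= para_size) {
      break;  // enough parameter volume collected
    }
  }
  return {picked, from};
}

The real method works on AllreduceNode objects and, as the hunk above shows, also returns the new boundary value so the caller can continue from where the last selection stopped.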
@@ -60,9 +60,9 @@ class AllreduceGraph { void PrintAllredueGraphInfo() const; void PrintArnodeVec() const; void PrintArnodeSet() const; - const std::unordered_set& cnode_set() const { return cnode_set_; } + const std::unordered_set &cnode_set() const { return cnode_set_; } CNodePtr head_cnode() const { return head_cnode_; } - Status set_head_cnode(const CNodePtr& node); + Status set_head_cnode(const CNodePtr &node); double max() const { return max_; } private: diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc index 6be588928a2..113d4ec59b6 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) { +Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { if (next_node == nullptr) { MS_LOG(ERROR) << "next_node is nullptr!"; return FAILED; @@ -30,7 +30,7 @@ Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) { return SUCCESS; } -Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max) { +Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { if (prev_node == nullptr) { MS_LOG(ERROR) << "next_node is nullptr!"; return FAILED; @@ -46,7 +46,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do *max = depend_feat_size_; } std::queue next_queue; - for (auto& next : next_) { + for (auto &next : next_) { next_queue.push(next); } while (!next_queue.empty()) { @@ -55,7 +55,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do if (ele->depend_feat_size() > *max) { *max = ele->depend_feat_size(); } - for (auto& next : ele->next()) { + for (auto &next : ele->next()) { next_queue.push(next); } next_queue.pop(); @@ -63,7 +63,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do return SUCCESS; } -Status AllreduceNode::Init(const CNodePtr& cnode_ptr) { +Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { if (cnode_ptr == nullptr) { MS_LOG(ERROR) << "cnode_ptr is nullptr!"; return FAILED; @@ -72,7 +72,7 @@ Status AllreduceNode::Init(const CNodePtr& cnode_ptr) { return SUCCESS; } -Status AllreduceNode::AddPara(const AnfNodePtr& node_ptr) { +Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { if (node_ptr == nullptr) { MS_LOG(ERROR) << "node_ptr is nullptr!"; return FAILED; @@ -99,7 +99,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr& node_ptr) { return SUCCESS; } -Status AllreduceNode::RemovePara(const AnfNodePtr& node_ptr) { +Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { if (node_ptr == nullptr) { MS_LOG(ERROR) << "node_ptr is nullptr!"; return FAILED; @@ -115,7 +115,7 @@ Status AllreduceNode::RemovePara(const AnfNodePtr& node_ptr) { void AllreduceNode::ToString() const { MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << "para size: " << paras_.size(); - for (auto& para : paras_) { + for (auto &para : paras_) { MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para); } MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h index d9ba98c3a27..db1c4e3f2ef 100644 --- 
a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h @@ -33,23 +33,23 @@ class AllreduceNode { public: AllreduceNode() : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} - Status Init(const CNodePtr& cnode_ptr); - Status AddPara(const AnfNodePtr& node_ptr); - Status RemovePara(const AnfNodePtr& node_ptr); - const std::unordered_set& paras() const { return paras_; } + Status Init(const CNodePtr &cnode_ptr); + Status AddPara(const AnfNodePtr &node_ptr); + Status RemovePara(const AnfNodePtr &node_ptr); + const std::unordered_set &paras() const { return paras_; } double curr_para_size() const { return curr_para_size_; } virtual ~AllreduceNode() = default; // Add previous node // prev_node is the previous to be added // max is the current max depend_feat_size of the AllreduceGraph - Status AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max); - Status AddNext(const AllreduceNodePtr& next_node); + Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); + Status AddNext(const AllreduceNodePtr &next_node); double depend_feat_size() const { return depend_feat_size_; } void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } - const std::vector& next() const { return next_; } + const std::vector &next() const { return next_; } void ToString() const; - bool operator<(const AllreduceNode& node) const { return depend_feat_size_ < node.depend_feat_size(); } - bool operator>(const AllreduceNode& node) const { return depend_feat_size_ > node.depend_feat_size(); } + bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } + bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } private: CNodePtr cnode_ptr_; diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc index 190f589bb5d..ad3a3a1298f 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -void Simplify(CostPtrList* clist_ptrs) { +void Simplify(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method // excludes the cost with greater computation_cost_ and greater communication_cost. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} @@ -44,7 +44,7 @@ void Simplify(CostPtrList* clist_ptrs) { *clist_ptrs = std::move(ret); } -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { +void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. 
if (!COST_MODEL_SIMPLIFY_CALCULATION) { @@ -66,7 +66,7 @@ void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { *clist_ptrs = std::move(ret); } -void RefineForPracticalCost(const CostPtr& origin_cost, bool is_redistribution) { +void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { MS_EXCEPTION_IF_NULL(origin_cost); if (is_redistribution) { // Redistribution cost diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index 9e9003848b5..2cb24dd7f36 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -44,7 +44,7 @@ using RedistributionOpListPtr = std::shared_ptr& decision_ = nullptr) + Cost(double computation, double commuication, const std::shared_ptr &decision_ = nullptr) : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { memory_with_reuse_ = 0.0; communication_without_parameter_ = 0.0; @@ -76,8 +76,8 @@ class StrategyWithCost { StrategyWithCost(StrategyPtr strategy, std::vector inputs_, std::vector outputs_) : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} - StrategyWithCost(const StrategyWithCost& swc) = delete; - StrategyWithCost(StrategyWithCost&& swc) + StrategyWithCost(const StrategyWithCost &swc) = delete; + StrategyWithCost(StrategyWithCost &&swc) : strategy_ptr(swc.strategy_ptr), inputs_ptr(swc.inputs_ptr), outputs_ptr(swc.outputs_ptr), @@ -295,9 +295,9 @@ using StarEliminationDecisionPtr = std::shared_ptr; using FinalDecisionPtr = std::shared_ptr; using FinalSingleDecisionPtr = std::shared_ptr; -void Simplify(CostPtrList* clist); -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist); -void RefineForPracticalCost(const CostPtr&, bool is_redistribution); +void Simplify(CostPtrList *clist); +void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist); +void RefineForPracticalCost(const CostPtr &, bool is_redistribution); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc index dd21096fcc4..8d439f15228 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -Status GetStrategy(const CostGraphPtr& graph) { +Status GetStrategy(const CostGraphPtr &graph) { MS_LOG(INFO) << "Searching strategies begins."; MS_EXCEPTION_IF_NULL(graph); std::vector eliminations; @@ -141,7 +141,7 @@ Status RecoverStrategy(std::vector eliminations) { auto elimination = (*rit)->cast(); auto new_edge = elimination->new_edge_; MS_EXCEPTION_IF_NULL(new_edge); - auto& edges = elimination->edges_; + auto &edges = elimination->edges_; auto decision = new_edge->selected_cost()->decision_ptr_->cast(); for (size_t j = 0; j < edges.size(); ++j) { MS_EXCEPTION_IF_NULL(edges[j]); diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h index 6d43218e19f..efedba7d105 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h @@ -65,7 +65,7 @@ struct OpElimination : public Elimination { // Edge Elimination struct EdgeElimination : public Elimination { - EdgeElimination(const 
EdgePtr& n_edge, std::vector eds) + EdgeElimination(const EdgePtr &n_edge, std::vector eds) : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} std::vector edges_; @@ -139,7 +139,7 @@ using TriangleEliminationPtr = std::shared_ptr; using StarEliminationPtr = std::shared_ptr; // Phase 1 and Phase 2 -Status GetStrategy(const CostGraphPtr& graph); +Status GetStrategy(const CostGraphPtr &graph); // Phase 3 Status RecoverStrategy(std::vector eliminations); diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc index 21e67f9f7b2..6973830779b 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc @@ -28,19 +28,19 @@ namespace mindspore { namespace parallel { Status Edge::InitEdgeCost() { bool has_available_cost = false; - for (auto& swc : prev_op_->GetStrategyCost()) { + for (auto &swc : prev_op_->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(swc); pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); } - for (auto& swc : next_op_->GetStrategyCost()) { + for (auto &swc : next_op_->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(swc); next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr)); } if (is_identity_edge) { - for (auto& target_output : pre_op_output_) { + for (auto &target_output : pre_op_output_) { auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); auto target_output_str = target_output.first; - for (auto& target_input : next_op_input_) { + for (auto &target_input : next_op_input_) { auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); auto target_input_str = target_input.first; if (target_output_lyt == target_input_lyt) { @@ -57,12 +57,12 @@ Status Edge::InitEdgeCost() { } } } else { - for (auto& target_output : pre_op_output_) { + for (auto &target_output : pre_op_output_) { auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); auto target_output_str = target_output.first; auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_]; auto type = prev_op_->outputs_type()[prev_op_output_index_]; - for (auto& target_input : next_op_input_) { + for (auto &target_input : next_op_input_) { auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); auto target_input_str = target_input.first; CostPtr cost; @@ -99,8 +99,8 @@ Status Edge::InitEdgeCost() { return Status::SUCCESS; } -Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, const TensorLayout& next_op_input_layout, - size_t type_length, TypePtr type, CostPtr* cost) { +Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t type_length, TypePtr type, CostPtr *cost) { MS_EXCEPTION_IF_NULL(prev_op_); MS_EXCEPTION_IF_NULL(cost); RankList dev_list = prev_op_->global_device_list(); @@ -148,9 +148,9 @@ CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) { return result; } -CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr, const std::vector& edges, - const StrategyPtr& input_st_ptr) { - std::function LocalGetCostList = [&](const EdgePtr& edge) { +CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector &edges, + const StrategyPtr &input_st_ptr) { + std::function LocalGetCostList = [&](const 
EdgePtr &edge) { MS_EXCEPTION_IF_NULL(edge); return edge->GetCostList(output_st_ptr, input_st_ptr); }; @@ -174,7 +174,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr result.push_back(new_cost); return; } - for (auto& c : all_cost_list[k]) { + for (auto &c : all_cost_list[k]) { MS_EXCEPTION_IF_NULL(c); selected_cost_list[k] = c; recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, @@ -187,11 +187,11 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr return result; } -void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector& edges, OperatorInfoPtr) { +void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector &edges, OperatorInfoPtr) { bool valid = false; - for (const auto& output_pair : pre_op_output_) { + for (const auto &output_pair : pre_op_output_) { StrategyPtr output_st_ptr = output_pair.first; - for (const auto& input_pair : next_op_input_) { + for (const auto &input_pair : next_op_input_) { StrategyPtr input_st_ptr = input_pair.first; CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr); CostPtrKey key = {output_st_ptr, input_st_ptr}; @@ -206,14 +206,14 @@ void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector } } -void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& left_cost_list, - const CostPtrList& middle_cost_list, const CostPtrList& right_cost_list, - CostPtrList* ret_cost_list) { - for (auto& left_cost : left_cost_list) { +void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list) { + for (auto &left_cost : left_cost_list) { MS_EXCEPTION_IF_NULL(left_cost); - for (auto& middle_cost : middle_cost_list) { + for (auto &middle_cost : middle_cost_list) { MS_EXCEPTION_IF_NULL(middle_cost); - for (auto& right_cost : right_cost_list) { + for (auto &right_cost : right_cost_list) { MS_EXCEPTION_IF_NULL(right_cost); double computation = left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; @@ -238,14 +238,14 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr } } -CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr& e1, const StrategyPtr& output_st_ptr, - const OperatorInfoPtr& op, const EdgePtr& e2, - const StrategyPtr& input_st_ptr) { +CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr, + const OperatorInfoPtr &op, const EdgePtr &e2, + const StrategyPtr &input_st_ptr) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(e1); MS_EXCEPTION_IF_NULL(e2); CostPtrList result; - for (const auto& op_strategy : op->GetStrategyCost()) { + for (const auto &op_strategy : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_strategy); auto middle_strategy = op_strategy->strategy_ptr; CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), @@ -255,11 +255,11 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr& e1, const StrategyP return result; } -void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op, const EdgePtr& e2) { +void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) { bool valid = false; - for (const auto& output_pair : pre_op_output_) { + for (const auto &output_pair : pre_op_output_) { 
StrategyPtr output_st_ptr = output_pair.first; - for (const auto& input_pair : next_op_input_) { + for (const auto &input_pair : next_op_input_) { StrategyPtr input_st_ptr = input_pair.first; CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr); @@ -283,8 +283,8 @@ Status Edge::CalculateMemoryCost() { if (is_output_parameter_involve_ == 0) { // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is // unnecessary to keep them in memory. - for (auto& cost_kv : cost_map_) { - auto& cost_v = cost_kv.second; + for (auto &cost_kv : cost_map_) { + auto &cost_v = cost_kv.second; if (!cost_v.empty()) { cost_v[0]->memory_with_reuse_ = 0; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h index f9741257493..e760c24c345 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h @@ -37,9 +37,9 @@ using EdgePtr = std::shared_ptr; class Edge { // An 'Edge' connects two Operators in the CostGraph. public: - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const size_t& output_index_, const size_t& input_index_, - const bool& is_com) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -49,9 +49,9 @@ class Edge { is_identity_edge = false; } - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const size_t& output_index_, const size_t& input_index_, - const bool& is_com, const bool& is_iden) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com, const bool &is_iden) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -60,9 +60,9 @@ class Edge { is_combined_(is_com), is_identity_edge(is_iden) {} - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const std::vector& output_indexs_, - const std::vector& input_indexs_, const bool& is_com) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const std::vector &output_indexs_, + const std::vector &input_indexs_, const bool &is_com) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -83,13 +83,13 @@ class Edge { // For two operators u--->v, given the output tensor layout of u, // and the input tensor layout of v, return the redistribution cost, // and the op_list to carry out the redistribution. 
- Status GetRedistributionCost(const TensorLayout& prev_op_output_layout, const TensorLayout& next_op_input_layout, - size_t, TypePtr type, CostPtr* cost); + Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t, TypePtr type, CostPtr *cost); - void set_pre_op_output(const std::vector, std::vector>>& output_set) { + void set_pre_op_output(const std::vector, std::vector>> &output_set) { pre_op_output_ = output_set; } - void set_next_op_input(const std::vector, std::vector>>& input_set) { + void set_next_op_input(const std::vector, std::vector>> &input_set) { next_op_input_ = input_set; } @@ -109,27 +109,27 @@ class Edge { std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } std::vector next_op_input_indexs() const { return next_op_input_indexs_; } - CostPtrList CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr, - const std::vector>& edges, - const StrategyPtr& input_st_ptr); + CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, + const std::vector> &edges, + const StrategyPtr &input_st_ptr); // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. This method is used to // set cost for this new edge - void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector>& edges, + void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, std::shared_ptr v); - void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& left_cost_list, - const CostPtrList& middle_cost_list, const CostPtrList& right_cost_list, - CostPtrList* ret_cost_list); + void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list); - CostPtrList CreateOpEliminationCostList(const std::shared_ptr& e1, const StrategyPtr& output_st_ptr, - const std::shared_ptr& op, const std::shared_ptr& e2, - const StrategyPtr& input_st_ptr); + CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, + const std::shared_ptr &op, const std::shared_ptr &e2, + const StrategyPtr &input_st_ptr); // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. // This method is used to set cost for this new edge - void OpEliminationSetNewCost(const std::shared_ptr& e1, const std::shared_ptr& op, - const std::shared_ptr& e2); + void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, + const std::shared_ptr &e2); - void set_selected_cost(const CostPtr& cost) { selected_cost_ = cost; } - const CostPtr& selected_cost() const { return selected_cost_; } + void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } + const CostPtr &selected_cost() const { return selected_cost_; } void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. 
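For illustration only, not a hunk of this patch: the Simplify comment in costmodel.cc above describes pruning a cost list so that no kept entry has both a larger computation cost and a larger communication cost than another, with the example {<100, 20>, <200, 10>, <300, 50>} reducing to {<200, 10>, <100, 20>}. A minimal sketch of that rule, using the hypothetical stand-ins CostLite and SimplifyLite rather than the real CostPtrList, could look like:

#include <algorithm>
#include <limits>
#include <vector>

// Hypothetical stand-in for the real Cost: only the two fields the pruning compares.
struct CostLite {
  double computation_cost;
  double communication_cost;
};

// Keep a cost only if no cheaper-to-compute cost already achieved less communication.
// Sorting by increasing computation cost lets a single pass track the best communication seen so far.
std::vector<CostLite> SimplifyLite(std::vector<CostLite> costs) {
  std::sort(costs.begin(), costs.end(),
            [](const CostLite &a, const CostLite &b) { return a.computation_cost < b.computation_cost; });
  std::vector<CostLite> kept;
  double best_communication = std::numeric_limits<double>::max();
  for (const auto &c : costs) {
    if (c.communication_cost < best_communication) {
      kept.push_back(c);
      best_communication = c.communication_cost;
    }
  }
  return kept;
}

Applied to {{100, 20}, {200, 10}, {300, 50}} this keeps {100, 20} and {200, 10}, the same set as the example in the comment, here ordered by increasing computation cost.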
diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index c56d3a6fbd7..501a983a957 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -144,7 +144,7 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { } } -void CostGraph::RemoveOperator(const OperatorInfoPtr& op) { +void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { for (auto it = ops_.begin(); it != ops_.end();) { if ((*it) == op) { it = ops_.erase(it); @@ -154,19 +154,19 @@ void CostGraph::RemoveOperator(const OperatorInfoPtr& op) { } } -bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr& op_test) { +bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { struct IsInGraph { const OperatorInfoPtr test_; - explicit IsInGraph(const OperatorInfoPtr& n) : test_(n) {} - bool operator()(const OperatorInfoPtr& in) const { return (test_ == in); } + explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {} + bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); } }; return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); } -bool CostGraph::IsEdgeInCostGraph(const std::string& test_edge_name, size_t output_index, size_t input_index) { - for (auto& edge_pair : edges_) { +bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { + for (auto &edge_pair : edges_) { auto edges = edge_pair.second; - for (auto& edge : edges) { + for (auto &edge : edges) { MS_EXCEPTION_IF_NULL(edge); bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) && (edge->next_op_input_index() == input_index); @@ -182,12 +182,12 @@ std::vector> CostGraph::ConstructConnectedComponents( std::vector alive_ops) { std::map visited; - for (auto& op : alive_ops) { + for (auto &op : alive_ops) { visited[op] = false; } MS_LOG(INFO) << "visited: " << visited.size() << "."; - for (auto& op : alive_ops) { + for (auto &op : alive_ops) { if ((!visited[op]) && op->is_alive()) { std::shared_ptr new_component = std::make_shared(); MS_EXCEPTION_IF_NULL(new_component); @@ -199,14 +199,14 @@ std::vector> CostGraph::ConstructConnectedComponents( return connected_compoents_; } -void CostGraph::DFS(const OperatorInfoPtr& current_op, std::map* visited, - const std::shared_ptr& component) { +void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component) { MS_EXCEPTION_IF_NULL(visited); MS_EXCEPTION_IF_NULL(component); visited->at(current_op) = true; component->AddOperator(current_op); - for (auto& edge : current_op->succ_edges()) { + for (auto &edge : current_op->succ_edges()) { bool bool_test = (visited->find(edge->next_operator()) != visited->end()) && (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive(); if (bool_test) { @@ -215,7 +215,7 @@ void CostGraph::DFS(const OperatorInfoPtr& current_op, std::mapprev_edges()) { + for (auto &edge : current_op->prev_edges()) { bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) && (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive(); if (bool_test) { @@ -226,14 +226,14 @@ void CostGraph::DFS(const OperatorInfoPtr& current_op, std::map v -CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::shared_ptr& e, - const OperatorInfoPtr& v) { +CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr 
&u, const std::shared_ptr &e, + const OperatorInfoPtr &v) { MS_EXCEPTION_IF_NULL(u); MS_EXCEPTION_IF_NULL(v); MS_EXCEPTION_IF_NULL(e); CostPtrList ret; - for (const auto& u_strategy : u->GetStrategyCost()) { - for (const auto& v_strategy : v->GetStrategyCost()) { + for (const auto &u_strategy : u->GetStrategyCost()) { + for (const auto &v_strategy : v->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(u_strategy); MS_EXCEPTION_IF_NULL(v_strategy); auto u_strategy_ptr = u_strategy->strategy_ptr; @@ -241,9 +241,9 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: CostPtrList clist1 = u_strategy->cost_list; CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr); CostPtrList clist3 = v_strategy->cost_list; - for (const auto& cost1 : clist1) { - for (const auto& cost2 : clist2) { - for (const auto& cost3 : clist3) { + for (const auto &cost1 : clist1) { + for (const auto &cost2 : clist2) { + for (const auto &cost3 : clist3) { MS_EXCEPTION_IF_NULL(cost1); MS_EXCEPTION_IF_NULL(cost2); MS_EXCEPTION_IF_NULL(cost3); @@ -274,14 +274,14 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: } // Create final cost list for the graph containing a signle node: u -CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { +CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { MS_EXCEPTION_IF_NULL(u); CostPtrList ret; - for (const auto& u_strategy : u->GetStrategyCost()) { + for (const auto &u_strategy : u->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(u_strategy); auto u_strategy_ptr = u_strategy->strategy_ptr; CostPtrList clist1 = u_strategy->cost_list; - for (const auto& cost1 : clist1) { + for (const auto &cost1 : clist1) { MS_EXCEPTION_IF_NULL(cost1); auto decision = std::make_shared(u_strategy_ptr, cost1); auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); @@ -299,16 +299,16 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { return ret; } -CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) { CostPtrList after_mem_filter; // Filter out the valid costs - for (auto& a_cost : cost_list) { + for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); } } - std::function LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { + std::function LocalCompare = [&](CostPtr init, const CostPtr &cost_x) { MS_EXCEPTION_IF_NULL(cost_x); if (init == nullptr || cost_x->computation_cost_ < memory) { init = cost_x; @@ -319,7 +319,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare); } -CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { // Select the cost with minimum training time. Currently, the training time is modeled as = // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ if (cost_list.empty()) { @@ -329,7 +329,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d CostPtrList after_mem_filter; double minimum_memory = DBL_MAX; // Filter out the valid costs. 
- for (auto& a_cost : cost_list) { + for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); } else if (a_cost->memory_with_reuse_ < minimum_memory) { @@ -371,7 +371,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d return ret; } -CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_cost_list, +CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, double available_memory) { CostPtrList selected_cost_list(all_cost_list.size(), nullptr); double minimum = DBL_MAX, total_memory = 0.0; @@ -418,7 +418,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect } MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; - for (auto& c : all_cost_list[k]) { + for (auto &c : all_cost_list[k]) { selected_cost_list[k] = c; recursive(k + 1); } @@ -427,7 +427,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect return ret; } -Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector& alive_ops) { +Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; auto connected_components = ConstructConnectedComponents(alive_ops); MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; @@ -516,7 +516,7 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector alive_ops; - (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr& op) { + (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); if (op->is_alive()) { alive_ops.push_back(op); @@ -620,7 +620,7 @@ Status CostGraph::SearchStrategy() { // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated // return the v and the edge u --> v OperatorInfoPtr CostGraph::CheckOpElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1; if (bool_test) { if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) { @@ -633,21 +633,21 @@ OperatorInfoPtr CostGraph::CheckOpElimination() const { // Check the graph whether an EdgeElimination can be performed std::vector> CostGraph::CheckEdgeElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); if (!op->is_alive()) continue; - std::map count; - for (auto& edge : op->GetAliveSuccEdges()) { + std::map count; + for (auto &edge : op->GetAliveSuccEdges()) { MS_EXCEPTION_IF_NULL(edge); auto v = edge->next_operator(); count[v.get()]++; } - for (auto& pair : count) { - auto* op_ptr = pair.first; + for (auto &pair : count) { + auto *op_ptr = pair.first; int op_count = pair.second; if (op_count > 1) { std::vector> ret; - for (auto& edge : op->GetAliveSuccEdges()) { + for (auto &edge : op->GetAliveSuccEdges()) { MS_EXCEPTION_IF_NULL(edge); if (edge->next_operator().get() == op_ptr) { ret.push_back(edge); @@ -662,7 +662,7 @@ std::vector> CostGraph::CheckEdgeElimination() const { // Check the graph whether a MergeElimination can be performed OperatorInfoPtr CostGraph::CheckMergeElimination() const { - for (auto& op : ops_) { + 
for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1; if (bool_test) { @@ -678,7 +678,7 @@ OperatorInfoPtr CostGraph::CheckMergeElimination() const { // Check the graph whether a ContractElimination can be performed OperatorInfoPtr CostGraph::CheckContractElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty(); if (bool_test) { @@ -696,7 +696,7 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const { // Check the graph whether a TriangleElimination can be performed std::pair> CostGraph::CheckTriangleElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2); if (bool_test) { @@ -707,13 +707,13 @@ std::pair> CostGraph::CheckTriangleElimin auto first_op = edge1->next_operator(); auto second_op = edge2->next_operator(); MS_EXCEPTION_IF_NULL(first_op); - for (auto& first_op_succ_edge : first_op->GetAliveSuccEdges()) { + for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) { if (first_op_succ_edge->next_operator() == second_op) { return {op, first_op_succ_edge}; } } MS_EXCEPTION_IF_NULL(second_op); - for (auto& second_op_succ_edge : second_op->GetAliveSuccEdges()) { + for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) { if (second_op_succ_edge->next_operator() == first_op) { return {op, second_op_succ_edge}; } @@ -726,7 +726,7 @@ std::pair> CostGraph::CheckTriangleElimin // Check the graph whether a StarElimination can be performed. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. OperatorInfoPtr CostGraph::CheckStarElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1); if (bool_test) { @@ -738,7 +738,7 @@ OperatorInfoPtr CostGraph::CheckStarElimination() const { // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge. -std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr& op) { +std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr &op) { // in this case, the operators are organised in the form of u-->op-->v, and the goal // is to eliminate 'op'. MS_EXCEPTION_IF_NULL(op); @@ -786,7 +786,7 @@ std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr& op) { // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges', // and sets new costlist for the new edge. 
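The EliminationEdges hunk below and the EliminationOp hunk above share one cost-combination step: for a fixed pair of strategies at the surviving endpoints, every combination of costs along the eliminated path is summed into a candidate cost of the replacement edge. The sketch below restates that step with simplified, hypothetical types; the real code additionally records an elimination decision on each new cost and prunes dominated candidates.

#include <memory>
#include <vector>

// Simplified illustration only; real entries also carry memory figures and a
// decision pointer.
struct PathCost {
  double computation_cost_ = 0.0;
  double communication_cost_ = 0.0;
};
using PathCostPtr = std::shared_ptr<PathCost>;
using PathCostList = std::vector<PathCostPtr>;

// For one fixed (u-strategy, v-strategy) pair in u --> op --> v, fold the cost
// choices of the left edge, the eliminated operator and the right edge into
// candidate costs of the new u --> v edge.
PathCostList CombineEliminatedPath(const PathCostList &left_edge, const PathCostList &op_costs,
                                   const PathCostList &right_edge) {
  PathCostList combined;
  for (const auto &l : left_edge) {
    for (const auto &m : op_costs) {
      for (const auto &r : right_edge) {
        auto c = std::make_shared<PathCost>();
        c->computation_cost_ = l->computation_cost_ + m->computation_cost_ + r->computation_cost_;
        c->communication_cost_ = l->communication_cost_ + m->communication_cost_ + r->communication_cost_;
        combined.push_back(c);
      }
    }
  }
  return combined;
}

Edge elimination proper, handled by the hunk that follows, is the simpler case of the same idea: parallel edges between one pair of operators are merged into a single new edge whose cost list is combined per strategy pair.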
-std::shared_ptr CostGraph::EliminationEdges(const std::vector>& edges) { +std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; MS_EXCEPTION_IF_NULL(edges[0]); auto u = edges[0]->prev_operator(); @@ -796,7 +796,7 @@ std::shared_ptr CostGraph::EliminationEdges(const std::vectorname() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); std::vector output_indexs, input_indexs; - for (auto& edge : edges) { + for (auto &edge : edges) { MS_EXCEPTION_IF_NULL(edge); if (edge->is_combined()) { auto from_output_indexs = edge->prev_op_output_indexs(); @@ -824,18 +824,18 @@ std::shared_ptr CostGraph::EliminationEdges(const std::vectorcomputation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; @@ -862,7 +862,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the // target_op -OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { +OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); auto edge_ptr = op->GetAliveSuccEdges()[0]; @@ -871,13 +871,13 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; bool valid = false; - for (auto& tar_stra_cost : target_op->GetStrategyCost()) { + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(tar_stra_cost); auto tar_stra = tar_stra_cost->strategy_ptr; auto tar_clist_origin = tar_stra_cost->cost_list; CostPtrList tar_clist_new; - for (auto& op_stra_cost : op->GetStrategyCost()) { + for (auto &op_stra_cost : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_stra_cost); auto op_stra = op_stra_cost->strategy_ptr; auto op_clist = op_stra_cost->cost_list; @@ -904,17 +904,17 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' // for this contract under the strategy 'contract_op_stra' void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, - const CostPtrList& contract_op_cost_list, - const CostPtrList& edge_cost_list, StrategyPtr target_op_stra, - const CostPtrList& tar_cost_list, CostPtrList* tar_cost_list_new) { + const CostPtrList &contract_op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { - auto& contract_op_cost = contract_op_cost_list[i]; + auto &contract_op_cost = contract_op_cost_list[i]; MS_EXCEPTION_IF_NULL(contract_op_cost); for (size_t j = 0; j < edge_cost_list.size(); ++j) { - auto& edge_cost = edge_cost_list[j]; + auto &edge_cost = edge_cost_list[j]; MS_EXCEPTION_IF_NULL(edge_cost); for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto& tar_cost = tar_cost_list[k]; + auto &tar_cost = tar_cost_list[k]; MS_EXCEPTION_IF_NULL(tar_cost); double computation = contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; @@ -941,20 +941,20 @@ void 
CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the // target_op -OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { +OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); auto edge_ptr = op->GetAlivePrevEdges()[0]; MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; bool valid = false; - for (auto& tar_stra_cost : target_op->GetStrategyCost()) { + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(tar_stra_cost); auto tar_stra = tar_stra_cost->strategy_ptr; auto tar_clist_origin = tar_stra_cost->cost_list; CostPtrList tar_clist_new; - for (auto& op_stra_cost : op->GetStrategyCost()) { + for (auto &op_stra_cost : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_stra_cost); auto op_stra = op_stra_cost->strategy_ptr; auto op_clist = op_stra_cost->cost_list; @@ -978,19 +978,19 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { } void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, - StrategyPtr right_op_stra, const CostPtr& right_op_cost, - const CostPtrList& elimi_op_clist, - const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, - const CostPtrList& left_node_clist_origin, - CostPtrList* left_node_clist_new) { + StrategyPtr right_op_stra, const CostPtr &right_op_cost, + const CostPtrList &elimi_op_clist, + const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { MS_EXCEPTION_IF_NULL(right_edge_cost); MS_EXCEPTION_IF_NULL(right_op_cost); MS_EXCEPTION_IF_NULL(left_node_clist_new); - for (auto& elimi_op_cost : elimi_op_clist) { + for (auto &elimi_op_cost : elimi_op_clist) { MS_EXCEPTION_IF_NULL(elimi_op_cost); - for (auto& left_edge_cost : left_edge_clist) { + for (auto &left_edge_cost : left_edge_clist) { MS_EXCEPTION_IF_NULL(left_edge_cost); - for (auto& left_node_cost : left_node_clist_origin) { + for (auto &left_node_cost : left_node_clist_origin) { MS_EXCEPTION_IF_NULL(left_node_cost); double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; @@ -1015,16 +1015,16 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, } } -void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist, - const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra, - const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra, - const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, - const CostPtrList& left_node_clist_origin, - CostPtrList* left_node_clist_new) { +void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, + const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, + const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, + const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { MS_EXCEPTION_IF_NULL(elimi_op); - for (auto& right_node_cost : right_node_clist) { + 
for (auto &right_node_cost : right_node_clist) { MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto& right_edge_cost : right_edge_clist) { + for (auto &right_edge_cost : right_edge_clist) { MS_EXCEPTION_IF_NULL(right_edge_cost); CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, @@ -1033,8 +1033,8 @@ void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_o } } -OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, - const std::shared_ptr& edge_left_right) { +OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, + const std::shared_ptr &edge_left_right) { MS_EXCEPTION_IF_NULL(edge_left_right); MS_EXCEPTION_IF_NULL(elimi_op); MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; @@ -1056,19 +1056,19 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, } bool valid = false; - for (auto& left_node_stra_cost : left_node->GetStrategyCost()) { + for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(left_node_stra_cost); auto left_node_stra = left_node_stra_cost->strategy_ptr; auto left_node_clist_origin = left_node_stra_cost->cost_list; CostPtrList left_node_clist_new; - for (auto& elimi_op_stra_cost : elimi_op->GetStrategyCost()) { + for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; auto elimi_op_clist = elimi_op_stra_cost->cost_list; auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); - for (auto& right_node_stra_cost : right_node->GetStrategyCost()) { + for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(right_node_stra_cost); auto right_node_stra = right_node_stra_cost->strategy_ptr; auto right_node_clist = right_node_stra_cost->cost_list; @@ -1095,16 +1095,16 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, return left_node; } -void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_node_stra, - const CostPtrList& first_succ_node_clist, - const CostPtrList& first_succ_edge_clist, - const StrategyPtr& merged_op_stra, const CostPtrList& merged_op_clist, +void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, std::vector succ_nodes_stras, - CostPtrList& succ_edges_costs, CostPtrList& succ_nodes_costs, - CostPtrList* first_succ_node_clist_new) { - for (auto& first_succ_node_cost : first_succ_node_clist) { - for (auto& first_succ_edge_cost : first_succ_edge_clist) { - for (auto& merged_node_cost : merged_op_clist) { + CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, + CostPtrList *first_succ_node_clist_new) { + for (auto &first_succ_node_cost : first_succ_node_clist) { + for (auto &first_succ_edge_cost : first_succ_edge_clist) { + for (auto &merged_node_cost : merged_op_clist) { MS_EXCEPTION_IF_NULL(merged_node_cost); succ_nodes_stras[0] = first_succ_node_stra; succ_edges_costs[0] = first_succ_edge_cost; @@ -1141,12 +1141,12 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n } } -void CostGraph::CreateStarEliminationCostList(std::vector>& succ_edges, - const 
StrategyPtr& first_succ_node_stra, - const CostPtrList& first_succ_node_clist, - const CostPtrList& first_succ_edge_clist, - const StrategyPtr& merged_op_stra, const CostPtrList& merged_op_clist, - CostPtrList* first_succ_node_clist_new) { +void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, + const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + CostPtrList *first_succ_node_clist_new) { std::vector succ_nodes_stras(succ_edges.size(), nullptr); CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, @@ -1167,15 +1167,15 @@ void CostGraph::CreateStarEliminationCostList(std::vector> MS_EXCEPTION_IF_NULL(succ_edge); auto succ_node = succ_edge->next_operator(); MS_EXCEPTION_IF_NULL(succ_node); - for (auto& succ_node_stra_cost : succ_node->GetStrategyCost()) { + for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(succ_node_stra_cost); auto succ_node_stra = succ_node_stra_cost->strategy_ptr; auto succ_node_clist = succ_node_stra_cost->cost_list; auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); - for (auto& succ_node_cost : succ_node_clist) { + for (auto &succ_node_cost : succ_node_clist) { MS_EXCEPTION_IF_NULL(succ_node_cost); - for (auto& succ_edge_cost : succ_edge_clist) { + for (auto &succ_edge_cost : succ_edge_clist) { MS_EXCEPTION_IF_NULL(succ_edge_cost); succ_nodes_stras[k] = succ_node_stra; succ_edges_costs[k] = succ_edge_cost; @@ -1189,11 +1189,11 @@ void CostGraph::CreateStarEliminationCostList(std::vector> recursive(1); } -std::vector> CostGraph::EliminationStar(const OperatorInfoPtr& merged_op) { +std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { MS_EXCEPTION_IF_NULL(merged_op); auto succ_edges = merged_op->GetAliveSuccEdges(); MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; - for (auto& succ_edge : succ_edges) { + for (auto &succ_edge : succ_edges) { MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; } @@ -1205,13 +1205,13 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo // 'merged_op' is merged into first_node MS_EXCEPTION_IF_NULL(first_succ_node); - for (auto& first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { + for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; CostPtrList first_succ_node_clist_new; - for (auto& merged_op_stra_cost : merged_op->GetStrategyCost()) { + for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(merged_op_stra_cost); auto merged_op_stra = merged_op_stra_cost->strategy_ptr; auto merged_op_clist = merged_op_stra_cost->cost_list; @@ -1238,7 +1238,7 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo } Status CostGraph::InitSelectedStrategy() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); auto result = op->InitSelectedStrategy(op->selected_strategy()); if (result != SUCCESS) { @@ -1249,9 +1249,9 @@ Status 
CostGraph::InitSelectedStrategy() { } Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); - const auto& output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); + const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); if ((output_parameter != 0) && (output_parameter != 1)) { MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed."; return FAILED; @@ -1261,7 +1261,7 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { } Status CostGraph::CalculateOpsMemoryCost() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); if (op->CalculateMemoryCost() != SUCCESS) { MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; @@ -1272,9 +1272,9 @@ Status CostGraph::CalculateOpsMemoryCost() { } Status CostGraph::CalculateEdgesMemoryCost() { - for (auto& edge_pair : edges_) { - const auto& edges = edge_pair.second; - for (auto& one_edge : edges) { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { if (one_edge->CalculateMemoryCost() != SUCCESS) { MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; return FAILED; @@ -1284,7 +1284,7 @@ Status CostGraph::CalculateEdgesMemoryCost() { return SUCCESS; } -OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) const { +OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { for (auto one_op : ops_) { if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { if (one_op->refkey_parameter_name() == p_name) { @@ -1295,7 +1295,7 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c return nullptr; } Status CostGraph::CorrectOpsMemoryCost() { - for (auto& one_op : ops_) { + for (auto &one_op : ops_) { if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) { if (one_op->GetAliveSuccEdges().size() > 1) { // Filter out the case when the TmpIdentity being used by multiple operators diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index e701a377b97..530f67ba453 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -70,7 +70,7 @@ class CostGraph { costmodel_beta_ = DEFAULT_COST_MODEL_BETA; } ~CostGraph() = default; - void AddOperator(const OperatorInfoPtr& op) { ops_.push_back(op); } + void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } OperatorInfoPtr FindOperatorByIndex(size_t index) { if (index >= ops_.size()) { MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; @@ -78,29 +78,29 @@ class CostGraph { } return ops_[index]; } - void RemoveOperator(const OperatorInfoPtr& op); - bool IsOperatorInCostGraph(const OperatorInfoPtr& op); + void RemoveOperator(const OperatorInfoPtr &op); + bool IsOperatorInCostGraph(const OperatorInfoPtr &op); // the edge is in the form: u --> v - void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr& edge) { + void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { std::vector curr_edges(edges_[{u_node, v_node}]); curr_edges.push_back(edge); edges_[{u_node, v_node}] = curr_edges; } // An edge is uniquely identified by its name, and 
its output index and input index. - bool IsEdgeInCostGraph(const std::string&, size_t, size_t); + bool IsEdgeInCostGraph(const std::string &, size_t, size_t); void SetDeviceMemoryAndCostParameter(); std::vector> ConstructConnectedComponents(std::vector); - void DFS(const OperatorInfoPtr& current_op, std::map* visited, - const std::shared_ptr& component); + void DFS(const OperatorInfoPtr &current_op, std::map *visited, + const std::shared_ptr &component); - CostPtrList CreateFinalCostList(const OperatorInfoPtr& u, const EdgePtr& e, const OperatorInfoPtr& v); - CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr& u); - CostPtr SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory); - CostPtr SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory); - CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_costlist, double memory); - Status SearchStrategyForMultiNodeFinalGraph(const std::vector&); + CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); + CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); + CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory); + CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); + CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); + Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { return edges_[{u_node, v_node}]; } @@ -145,36 +145,36 @@ class CostGraph { */ OperatorInfoPtr CheckStarElimination() const; // Applying Operator Elimination in DP algorithm - EdgePtr EliminationOp(const OperatorInfoPtr& op); + EdgePtr EliminationOp(const OperatorInfoPtr &op); // Applying Edge Elimination in DP algorithm - EdgePtr EliminationEdges(const std::vector& edges); + EdgePtr EliminationEdges(const std::vector &edges); // Applying Merge Elimination in DP algorithm - OperatorInfoPtr EliminationMerge(const OperatorInfoPtr& op); - void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& op_cost_list, - const CostPtrList& edge_cost_list, StrategyPtr tar_op_strategy, - const CostPtrList& tar_cost_list, CostPtrList* tar_cost_list_new); + OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op); + void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new); // Applying Contract Elimination in DP algorithm - OperatorInfoPtr EliminationContract(const OperatorInfoPtr& op); - void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList&, const CostPtrList&, StrategyPtr, - const CostPtrList&, CostPtrList*); + OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op); + void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr, + const CostPtrList &, CostPtrList *); // Applying Triangle Elimination in DP algorithm. 
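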
return the left_node - OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr& elimi_op, const EdgePtr& edge_left_right); - void CreateTriangleEliminationCostList(const OperatorInfoPtr&, const CostPtrList&, const CostPtrList&, - const StrategyPtr&, const StrategyPtr&, const StrategyPtr&, const CostPtrList&, - const CostPtrList&, const CostPtrList&, CostPtrList*); + OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right); + void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const StrategyPtr &, const StrategyPtr &, + const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *); // Given the relevant costlist, create the TriangleElimination cost - void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, const CostPtrList&, - const CostPtrList&, const CostPtr&, const CostPtrList&, CostPtrList*); + void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &, + const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *); // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. - std::vector EliminationStar(const OperatorInfoPtr& op); - void CreateStarEliminationCostList(std::vector&, const StrategyPtr&, const CostPtrList&, const CostPtrList&, - const StrategyPtr&, const CostPtrList&, CostPtrList*); - void CreateStarEliminationSubCostList(const StrategyPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&, - const CostPtrList&, std::vector, CostPtrList&, CostPtrList&, - CostPtrList*); + std::vector EliminationStar(const OperatorInfoPtr &op); + void CreateStarEliminationCostList(std::vector &, const StrategyPtr &, const CostPtrList &, + const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *); + void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const CostPtrList &, std::vector, + CostPtrList &, CostPtrList &, CostPtrList *); // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then // the memory cost can be resused. Status CalculateOpsMemoryCost(); @@ -186,16 +186,16 @@ class CostGraph { std::vector GetOperators() const { return ops_; } size_t GetNumPairs() const { return edges_.size(); } Status InitSelectedStrategy(); - OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const; + OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only // once (instead of multiple times), this method is used to correct this. 
Status CorrectOpsMemoryCost(); // Needed by rec_parser - void add_inputs_tensor_name(const std::vector& inputs_tensor_name) { + void add_inputs_tensor_name(const std::vector &inputs_tensor_name) { inputs_tensor_name_list_.push_back(inputs_tensor_name); } const std::vector> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; } - void add_tuple_getitem(const std::pair& tuple_getitem) { + void add_tuple_getitem(const std::pair &tuple_getitem) { auto ret = tuple_getitem_list_.insert(tuple_getitem); if (ret.second == false) { MS_LOG(EXCEPTION) << "The insert item is already exist."; diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 0192dce8b89..8ad8b46f32b 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -23,22 +23,22 @@ namespace mindspore { namespace parallel { -void OperatorCost::set_is_parameter(const std::vector& is_parameter) { is_parameter_ = is_parameter; } +void OperatorCost::set_is_parameter(const std::vector &is_parameter) { is_parameter_ = is_parameter; } -void OperatorCost::set_is_parameter_involve(const std::vector& is_parameter_inv) { +void OperatorCost::set_is_parameter_involve(const std::vector &is_parameter_inv) { is_parameter_involve_ = is_parameter_inv; } void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; } -void OperatorCost::SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths) { +void OperatorCost::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { inputs_type_lengths_ = input_lengths; outputs_type_lengths_ = output_lengths; } -double OperatorCost::GetMemoryCost(const std::vector& inputs, - const std::vector& outputs) const { +double OperatorCost::GetMemoryCost(const std::vector &inputs, + const std::vector &outputs) const { double result = 0.0; if (output_parameter_involve_ == 1) { // When this operator has multiple outputs, they all contributes to the memory. @@ -64,7 +64,7 @@ double OperatorCost::GetMemoryCost(const std::vector& inputs, } // return the per device communication cost in the forward phase. -double MatMulCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, +double MatMulCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t) const { TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -80,7 +80,7 @@ double MatMulCost::GetForwardCommCost(const std::vector& inputs, con } // return the per device communication cost in the forward phase. -double MatMulCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double MatMulCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not // fully utilize all devices @@ -107,8 +107,8 @@ double MatMulCost::GetBackwardCommCost(const std::vector& inputs, co // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t) const { +double MatMulCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) double result = 0.0; TensorInfo output0 = outputs[0]; @@ -126,7 +126,7 @@ double MatMulCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, +double MatMulCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; @@ -151,14 +151,14 @@ double MatMulCost::GetBackwardComputationCost(const std::vector& inp } // Return the per device communication cost in the forward phase. -double ActivationCost::GetForwardCommCost(const std::vector&, const std::vector&, +double ActivationCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // ReLU is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // Return the per device communication cost in the backward phase. -double ActivationCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ActivationCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -180,7 +180,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector& inputs // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); @@ -189,19 +189,20 @@ double ActivationCost::GetForwardComputationCost(const std::vector& // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } // Return the per device communication cost in the forward phase. -double SoftmaxCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // In the forward phase, the communication cost = 0 return 0.0; } // Return the per device communication cost in the backward phase. -double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -223,7 +224,7 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In the forward phase, the computation cost = slice(A) TensorInfo input0 = inputs[0]; @@ -233,46 +234,47 @@ double SoftmaxCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double TmpIdentityCost::GetForwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double TmpIdentityCost::GetBackwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double TmpIdentityCost::GetForwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetForwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { return 0.0; } // Return the per device PEAK memory cost contributed by this operator in a training iteration. 
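Before the TmpIdentityCost::GetMemoryCost hunk below (which returns 0.0 for the temporary identity operator), it is worth spelling out the primitive that the computation, communication and memory cost expressions in these hunks are built from: the byte size of a sliced tensor, that is, the product of its slice-shape dimensions times the per-element type length. A small self-contained sketch with illustrative names follows; the real code uses ListProduct over the Shape type and the inputs_type_lengths_ array.

#include <cstdint>
#include <vector>

// Illustrative helper mirroring ListProduct: multiply the dimensions of a
// slice shape together.
double SliceProduct(const std::vector<int64_t> &slice_shape) {
  double product = 1.0;
  for (int64_t dim : slice_shape) {
    product *= static_cast<double>(dim);
  }
  return product;
}

// Per-device byte contribution of one sliced tensor, the term accumulated by
// the cost expressions in the hunks above.
double SliceBytes(const std::vector<int64_t> &slice_shape, int type_length) {
  return SliceProduct(slice_shape) * static_cast<double>(type_length);
}

For instance, a float32 tensor of full shape [128, 256] split over 4 devices along the first dimension has slice shape [32, 256], so it contributes 32 * 256 * 4 = 32768 bytes per device.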
-double TmpIdentityCost::GetMemoryCost(const std::vector&, const std::vector&) const { +double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { return 0.0; } -double BatchParallelCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, +double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { double cost = 0.0; for (size_t i = 0; i < inputs.size(); ++i) { @@ -281,13 +283,13 @@ double BatchParallelCost::GetForwardComputationCost(const std::vector&, - const std::vector&, +double BatchParallelCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } -double BatchParallelCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -313,13 +315,13 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector& inp return result; } // return the per device communication cost in the forward phase. -double PReLUCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // prelu does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double PReLUCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[1]) { @@ -342,7 +344,7 @@ double PReLUCost::GetBackwardCommCost(const std::vector& inputs, con // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -354,8 +356,8 @@ double PReLUCost::GetForwardComputationCost(const std::vector& input // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetBackwardComputationCost(const std::vector& inputs, - const std::vector&, +double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; @@ -380,20 +382,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector&, const std::vector&, int32_t) const { +double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // onehot does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double OneHotCost::GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // onehot does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In onehot's forward phase, the computation cost = slice(A) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -402,29 +405,29 @@ double OneHotCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -435,13 +438,13 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double ReshapeCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, +double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -457,7 +460,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector& inputs, co } // return the per device communication cost in the backward phase. 
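ReshapeCost::GetBackwardCommCost in the hunk below follows the same pattern as the earlier backward communication costs in this file (for example MatMulCost and ActivationCost above): a cost is charged only when the corresponding input is a parameter whose slice does not already span every device of the stage, in which case the gradient of that parameter slice has to be all-reduced. A rough standalone restatement with hypothetical names; how the number of devices actually used is derived from the slice strategy is not shown here, so it is passed in as an argument.

#include <cstddef>
#include <cstdint>
#include <vector>

// Rough illustration of the shared backward-communication pattern. The real
// implementations compute used_device_num from the tensor's slice strategy
// and the device list of the stage.
double BackwardParamCommCost(bool is_parameter, const std::vector<int64_t> &param_slice_shape, int type_length,
                             size_t total_device_num, size_t used_device_num) {
  if (!is_parameter || used_device_num == total_device_num) {
    return 0.0;  // no gradient all-reduce term when the parameter already uses all devices
  }
  double bytes = static_cast<double>(type_length);
  for (int64_t dim : param_slice_shape) {
    bytes *= static_cast<double>(dim);
  }
  return bytes;  // modeled cost of all-reducing the parameter's gradient slice
}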
-double ReshapeCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -479,8 +482,8 @@ double ReshapeCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); @@ -496,12 +499,12 @@ double ReshapeCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double ReshapeCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } -double ArithmeticCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { double result; result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + @@ -509,8 +512,8 @@ double ArithmeticCost::GetForwardComputationCost(const std::vector& return result; } -double ArithmeticCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, - int32_t stage_id) const { +double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -544,7 +547,7 @@ double ArithmeticCost::GetBackwardComputationCost(const std::vector& return result; } -double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -580,7 +583,7 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs return result; } -bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_id) { +bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); @@ -589,8 +592,8 @@ bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_ return (total_device_num == IntToSize(strategy0)); } -double ReduceMethodCost::GetForwardCommCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMethodCost::GetForwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -611,7 +614,7 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector& input return result; } -double ReduceMethodCost::GetBackwardCommCost(const std::vector& inputs, 
const std::vector&, +double ReduceMethodCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -634,8 +637,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector& inpu return result; } -double ReduceMethodCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMethodCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -656,8 +659,8 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector return result; } -double ReduceMeanCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMeanCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -678,7 +681,7 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector& return result; } -double DropOutCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double DropOutCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { if (inputs.empty()) { return 0.0; @@ -689,13 +692,14 @@ double DropOutCost::GetForwardComputationCost(const std::vector& inp } // return the per device communication cost in the forward phase. -double GatherV2Cost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double GatherV2Cost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // GatherV2Cost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. 
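One more helper visible in the hunks above, IsDataParallel, decides whether a tensor layout is plain data parallelism by comparing the device count of the stage with the tensor's split count. Its body is only partially shown in the diff, so the sketch below assumes the split count comes from the first dimension (full shape divided by slice shape); treat that detail as an assumption rather than a statement about the actual implementation.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative restatement of the IsDataParallel check. Assumption: the split
// count is shape[0] / slice_shape[0], which the hunk above does not show.
bool IsDataParallelSketch(const std::vector<int64_t> &shape, const std::vector<int64_t> &slice_shape,
                          size_t total_device_num) {
  if (shape.empty() || slice_shape.empty() || slice_shape[0] == 0) {
    return false;
  }
  size_t split_count = static_cast<size_t>(shape[0] / slice_shape[0]);
  return split_count == total_device_num;
}

GatherV2Cost::GetBackwardCommCost, which the hunk below continues with, again follows the parameter all-reduce pattern sketched earlier.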
-double GatherV2Cost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double GatherV2Cost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -721,7 +725,7 @@ double GatherV2Cost::GetBackwardCommCost(const std::vector& inputs, return result; } -double GatherV2Cost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double GatherV2Cost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -731,12 +735,12 @@ double GatherV2Cost::GetForwardComputationCost(const std::vector& in return result; } -double GatherV2Cost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double GatherV2Cost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } -double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_.size() != inputs.size()) { @@ -769,7 +773,7 @@ double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, return result; } -double LayerNormCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double LayerNormCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { double result = 0.0; if (inputs_type_lengths_.size() != inputs.size()) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 37b054aa98c..a243f8adfae 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -63,31 +63,31 @@ class OperatorCost { } virtual ~OperatorCost() = default; - void set_is_parameter(const std::vector& is_parameter); - void set_is_parameter_involve(const std::vector&); + void set_is_parameter(const std::vector &is_parameter); + void set_is_parameter_involve(const std::vector &); void set_output_parameter_involve(int); - void SetInputAndOutputTypeLength(const std::vector& input_lengths, const std::vector& output_lengths); + void SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); std::vector inputs_type_lengths() const { return inputs_type_lengths_; } std::vector outputs_type_lengths() const { return outputs_type_lengths_; } // per device communication cost - virtual double GetCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; // per device computation cost - virtual double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) 
const = 0; - virtual double GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const = 0; - virtual double GetBackwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const = 0; + virtual double GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; + virtual double GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; // per device PEAK memory cost in a training iteration // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), // plus necessary inputs. - virtual double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const; + virtual double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const; protected: // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of @@ -113,23 +113,23 @@ class MatMulCost : public OperatorCost { ~MatMulCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using MatMulCostPtr = std::shared_ptr; @@ -140,21 +140,21 @@ class ActivationCost : public OperatorCost { ActivationCost() : OperatorCost(false) {} ~ActivationCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const 
std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ActivationCostPtr = std::shared_ptr; @@ -167,21 +167,21 @@ class SoftmaxCost : public OperatorCost { SoftmaxCost() : OperatorCost(false) {} ~SoftmaxCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t) const override; }; using SoftmaxCostPtr = std::shared_ptr; @@ -192,24 +192,24 @@ class TmpIdentityCost : public OperatorCost { TmpIdentityCost() : OperatorCost(false) {} ~TmpIdentityCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double 
GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override; + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; }; using TmpIdentityCostPtr = std::shared_ptr; @@ -219,21 +219,21 @@ class BatchParallelCost : public OperatorCost { BatchParallelCost() : OperatorCost(false) {} ~BatchParallelCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using BatchParallelCostPtr = std::shared_ptr; @@ -244,30 +244,30 @@ class VirtualDatasetCost : public OperatorCost { VirtualDatasetCost() : OperatorCost(false) {} ~VirtualDatasetCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double 
GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override { + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { return 0.0; } }; @@ -279,27 +279,27 @@ class GeneratorBaseCost : public OperatorCost { GeneratorBaseCost() : OperatorCost(false) {} ~GeneratorBaseCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // Generator ops don't have backward steps. 
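
The cost classes in the hunks above all follow the same shape: the public GetCommCost and GetComputationCost entry points return the sum of a forward and a backward term, and trivially parallel or generator-style operators override individual terms to return 0.0 (BatchParallelCost, for example, zeroes only the forward communication term). The following is a minimal self-contained sketch of that decomposition, not the MindSpore header itself; TensorArg is a hypothetical stand-in for the tensor-description type these methods actually take, which is not spelled out here.

    // Sketch of the cost-decomposition pattern, using a placeholder element type.
    #include <cstdint>
    #include <vector>

    struct TensorArg {};  // hypothetical stand-in for the real tensor descriptor

    class OperatorCostSketch {
     public:
      virtual ~OperatorCostSketch() = default;
      // Total per-device communication cost = forward part + backward part.
      double GetCommCost(const std::vector<TensorArg> &inputs, const std::vector<TensorArg> &outputs,
                         int32_t stage_id) const {
        return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
      }
      virtual double GetForwardCommCost(const std::vector<TensorArg> &, const std::vector<TensorArg> &,
                                        int32_t) const = 0;
      virtual double GetBackwardCommCost(const std::vector<TensorArg> &, const std::vector<TensorArg> &,
                                         int32_t) const = 0;
    };

    // Generator-style ops have no inputs and no backward step, so every term is 0.0.
    class GeneratorCostSketch : public OperatorCostSketch {
     public:
      double GetForwardCommCost(const std::vector<TensorArg> &, const std::vector<TensorArg> &,
                                int32_t) const override {
        return 0.0;
      }
      double GetBackwardCommCost(const std::vector<TensorArg> &, const std::vector<TensorArg> &,
                                 int32_t) const override {
        return 0.0;
      }
    };

    int main() {
      GeneratorCostSketch gen;
      std::vector<TensorArg> none;
      return gen.GetCommCost(none, none, /*stage_id=*/0) == 0.0 ? 0 : 1;
    }

Keeping the forward and backward contributions as separate virtuals is what lets each operator class above zero out or specialize only the terms that apply to it.
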
- double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -313,23 +313,23 @@ class PReLUCost : public OperatorCost { ~PReLUCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using PReLUCostPtr = std::shared_ptr; @@ -341,23 +341,23 @@ class OneHotCost : public OperatorCost { ~OneHotCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using OneHotCostPtr = std::shared_ptr; @@ -369,23 +369,23 @@ class SoftmaxCrossEntropyWithLogitsCost : 
public OperatorCost { ~SoftmaxCrossEntropyWithLogitsCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; @@ -398,27 +398,27 @@ class ReshapeCost : public OperatorCost { ~ReshapeCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ReshapeCostPtr = std::shared_ptr; @@ -429,22 +429,22 @@ class ArithmeticCost : public OperatorCost { ArithmeticCost() : OperatorCost(false) {} ~ArithmeticCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const 
std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ArithmeticCostPtr = std::shared_ptr; @@ -457,21 +457,21 @@ class ReduceMethodCost : public OperatorCost { ReduceMethodCost() : OperatorCost(true) {} ~ReduceMethodCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -488,7 +488,7 @@ class ReduceMeanCost : public ReduceMethodCost { ReduceMeanCost() : ReduceMethodCost(true) {} ~ReduceMeanCost() override = default; - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ReduceMeanCostPtr = std::shared_ptr; @@ -499,27 +499,27 @@ class GetNextCost : public OperatorCost { GetNextCost() : OperatorCost(false) {} ~GetNextCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double 
GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -532,23 +532,23 @@ class DropOutCost : public OperatorCost { DropOutCost() : OperatorCost(true) {} ~DropOutCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -562,21 +562,21 @@ class LayerNormCost : public OperatorCost { LayerNormCost() : OperatorCost(true) {} ~LayerNormCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) 
const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -590,21 +590,21 @@ class GatherV2Cost : public OperatorCost { GatherV2Cost() : OperatorCost(true) {} ~GatherV2Cost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t) const override; }; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 44d3642b9ca..6b438cb6703 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -35,7 +35,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) { new_tensor.tensor_shape.shape_c = c; new_tensor.tensor_shape.shape_h = h; new_tensor.tensor_shape.shape_w = w; - const TensorParam& tensor = new_tensor; + const TensorParam &tensor = new_tensor; return tensor; } @@ -71,7 +71,7 @@ Graph::NodeType MakeNewOperator(std::vector> ops, return NewOp; } -TensorParam Fill2DTensor(const std::vector>& ops, const size_t iter_ops, +TensorParam Fill2DTensor(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor) { if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); @@ -94,7 +94,7 @@ TensorParam Fill2DTensor(const std::vector>& ops, return 
NewTensor.tensor_parm; } -OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor) { for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); iter_input_tensors++) { @@ -118,7 +118,7 @@ OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, const size_t iter_input_tensors, Graph::NodeType NewTensor) { if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); @@ -145,8 +145,8 @@ TensorParam Complete2DInputs(const std::vector>& o return NewTensor.apply.arguments[iter_input_tensors]; } -std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names) { +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names) { std::shared_ptr graph(new Graph); if (ops.size() > SIZE_MAX / 2) { MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; @@ -161,7 +161,7 @@ std::shared_ptr ParseGraph(const std::vector>& input_tensor_names, std::shared_ptr graph) { +void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph) { for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); @@ -173,8 +173,8 @@ void MakeEdge(const std::vector>& input_tensor_names, s } } -size_t GetIndexInInputTensorNames(const std::vector>& input_tensor_name, - const std::string& input_name) { +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, + const std::string &input_name) { for (size_t index = 0; index < input_tensor_name.size(); index++) { if (input_tensor_name[index][0] == input_name) { return index; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 0d719c33d88..ae50ced418c 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -45,22 +45,22 @@ const TensorParam MakeTensor(int n, int c, int h, int w); Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops); -TensorParam Fill2DTensor(const std::vector>& ops, const size_t iter_ops, +TensorParam Fill2DTensor(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor); -OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor); -TensorParam Complete2DInputs(const std::vector>& ops, const size_t iter_ops, +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, const size_t iter_input_tensor, Graph::NodeType NewTensor); -std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names); +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names); -void MakeEdge(const std::vector>& input_tensor_names, std::shared_ptr graph); +void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph); -size_t GetIndexInInputTensorNames(const std::vector>& input_tensor_names, - const std::string& 
input_name); +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, + const std::string &input_name); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index ab216cb22cf..bc4aca896ba 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -73,11 +73,11 @@ void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_bef void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } -void ParallelContext::set_communication_backend(const std::string& communication_backend) { +void ParallelContext::set_communication_backend(const std::string &communication_backend) { communication_backend_ = communication_backend; } -bool ParallelContext::set_parallel_mode(const std::string& parallel_mode) { +bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) { auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); if (iter == PARALLEL_MODE_LIST.end()) { MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode; @@ -87,7 +87,7 @@ bool ParallelContext::set_parallel_mode(const std::string& parallel_mode) { return true; } -bool ParallelContext::set_strategy_search_mode(const std::string& strategy_search_mode) { +bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) { auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode); if (iter == STRATEGY_SEARCH_MODE_LIST.end()) { MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode; diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 265f5bac715..64261cb964a 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -40,8 +40,8 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; class ParallelContext { public: ~ParallelContext() = default; - ParallelContext(const ParallelContext&) = delete; - ParallelContext& operator=(const ParallelContext&) = delete; + ParallelContext(const ParallelContext &) = delete; + ParallelContext &operator=(const ParallelContext &) = delete; static std::shared_ptr GetInstance(); @@ -60,13 +60,13 @@ class ParallelContext { void set_global_rank(int32_t global_rank); int32_t global_rank() const { return global_rank_; } - void set_communication_backend(const std::string& communication_backend); + void set_communication_backend(const std::string &communication_backend); std::string communication_backend() const { return communication_backend_; } - bool set_parallel_mode(const std::string& parallel_mode); + bool set_parallel_mode(const std::string &parallel_mode); std::string parallel_mode() const { return parallel_mode_; } - bool set_strategy_search_mode(const std::string& strategy_search_mode); + bool set_strategy_search_mode(const std::string &strategy_search_mode); std::string strategy_search_mode() const { return strategy_search_mode_; } void set_parameter_broadcast(bool parameter_broadcast); diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/parallel/costmodel_context.h index 23c9f7cc8d1..99374830517 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.h +++ b/mindspore/ccsrc/parallel/costmodel_context.h @@ -28,8 +28,8 @@ namespace parallel { class CostModelContext { public: ~CostModelContext() = default; - 
CostModelContext(const CostModelContext&) = delete; - CostModelContext& operator=(const CostModelContext&) = delete; + CostModelContext(const CostModelContext &) = delete; + CostModelContext &operator=(const CostModelContext &) = delete; void ResetCostModel(); void ResetAlgoParameters(); diff --git a/mindspore/ccsrc/parallel/device_manager.cc b/mindspore/ccsrc/parallel/device_manager.cc index 0b34cedc006..45628bec650 100644 --- a/mindspore/ccsrc/parallel/device_manager.cc +++ b/mindspore/ccsrc/parallel/device_manager.cc @@ -30,15 +30,15 @@ namespace mindspore { namespace parallel { DeviceManagerPtr g_device_manager = nullptr; -Stage::Stage(const std::vector& devices, int num, int rank) +Stage::Stage(const std::vector &devices, int num, int rank) : devices_(devices), number_(num), rank_(rank) { gm_ = GroupManager(); } // NOTE: '-1' indicates ERROR -int Stage::global_rank(Group* g) const { return ((g == nullptr) ? rank_ : -1); } +int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string& backend) { +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { if (device_num <= 0) { MS_LOG(ERROR) << "'device_num' must be positive."; return false; @@ -87,7 +87,7 @@ void CheckGlobalDeviceManager() { } } -int32_t GetListMemberByIndex(size_t index, const RankList& devices) { +int32_t GetListMemberByIndex(size_t index, const RankList &devices) { size_t i = 0; int32_t result = 0; if ((devices.empty()) || (index >= devices.size())) { @@ -104,7 +104,7 @@ int32_t GetListMemberByIndex(size_t index, const RankList& devices) { return result; } -std::shared_ptr GetListMemberByIndex(size_t index, const std::vector>& device_list) { +std::shared_ptr GetListMemberByIndex(size_t index, const std::vector> &device_list) { size_t i = 0; std::shared_ptr result; if ((device_list.empty()) || (index >= device_list.size())) { @@ -123,8 +123,8 @@ std::shared_ptr GetListMemberByIndex(size_t index, const std::vector DeviceManager::GetStageById(int32_t stage_id) { return res; } int32_t index = 0; - for (auto& stage : stages_) { + for (auto &stage : stages_) { if (index == stage_id) return stage; index++; } @@ -224,7 +224,7 @@ RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const { << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); RankList res; int32_t index = 0; - for (auto& stage : stage_devices_) { + for (auto &stage : stage_devices_) { if (index == stage_id) { return stage; } @@ -280,19 +280,19 @@ Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device( std::vector DeviceManager::CreateDeviceListByRankList(RankList ranks) { std::vector dev_list; - for (auto& rank : ranks) { + for (auto &rank : ranks) { Device one = CreateNewDeviceByRank(rank); dev_list.push_back(one); } return dev_list; } -DeviceManager& DeviceManager::GetInstance() { +DeviceManager &DeviceManager::GetInstance() { static DeviceManager instance = DeviceManager(); return instance; } -std::string DeviceManager::FindRankListNameByHashName(const std::string& hash_name) { +std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) { std::string tmp = "WORLD_GROUP"; if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) { return tmp; @@ -305,7 +305,7 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string& hash_na return iter->second; } -std::string HashName(const std::string& origin_name) { return 
std::to_string(std::hash{}(origin_name)); } +std::string HashName(const std::string &origin_name) { return std::to_string(std::hash{}(origin_name)); } // Group name is generated using the increasing ranks of the devices. // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name @@ -343,8 +343,8 @@ std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { // Create the group with the given devices and the given name. The GroupManager // gm_ will create a new group only if there does not exit a group with the same // name. Otherwise, let the pointer g point to that group. -Group DeviceManager::CreateGroup(const std::string& group_name, - const std::vector& devices) { +Group DeviceManager::CreateGroup(const std::string &group_name, + const std::vector &devices) { if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; } @@ -354,7 +354,7 @@ Group DeviceManager::CreateGroup(const std::string& group_name, } // Create the group with only the given devices' ranks. -Group DeviceManager::CreateGroup(const RankList& dev_ranks) { +Group DeviceManager::CreateGroup(const RankList &dev_ranks) { std::unordered_set rank_set(dev_ranks.begin(), dev_ranks.end()); if (dev_ranks.size() != rank_set.size()) { MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list"; diff --git a/mindspore/ccsrc/parallel/device_manager.h b/mindspore/ccsrc/parallel/device_manager.h index e87c1d740f3..3afafe6a9c2 100644 --- a/mindspore/ccsrc/parallel/device_manager.h +++ b/mindspore/ccsrc/parallel/device_manager.h @@ -53,13 +53,13 @@ class Stage { explicit Stage(std::vector devices) : devices_(std::move(devices)), number_(0), rank_(0) { gm_ = GroupManager(); } - Stage(const std::vector& devices, int num, int rank); + Stage(const std::vector &devices, int num, int rank); ~Stage() = default; int GetStageNum() const { return number_; } size_t GetDevicesNum() const { return devices_.size(); } std::vector GetDevicesList() { return devices_; } - int global_rank(Group* g) const; + int global_rank(Group *g) const; private: std::vector devices_; @@ -70,11 +70,11 @@ class Stage { // This method is used for initializing the global DeviceManager 'g_device_manager', // arguments including 'device_num' and 'global_rank' -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string& backend); +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); void CheckGlobalDeviceManager(); -std::string HashName(const std::string& rank_list_name); +std::string HashName(const std::string &rank_list_name); class DeviceManager { // This class is used to manage the abstract devices, including group-related and stage-related management. 
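
HashName, defined in device_manager.cc in the hunk above and redeclared in device_manager.h, condenses a rank-list-derived group name into the decimal form of its hash, and FindRankListNameByHashName in the same file appears to be the reverse lookup for those hashed names. A small standalone sketch of the same idea follows; it assumes the standard std::hash<std::string> specialization, and the sample group name is purely illustrative rather than the exact string GenerateGroupNameByRanks produces.

    // Illustrative sketch: condense a readable group name into a numeric token.
    #include <functional>
    #include <iostream>
    #include <string>

    std::string HashNameSketch(const std::string &origin_name) {
      // Assumes std::hash<std::string>; the result is rendered as a decimal string.
      return std::to_string(std::hash<std::string>{}(origin_name));
    }

    int main() {
      // Hypothetical name built from device ranks in increasing order, as the
      // surrounding comments describe; the real naming scheme may differ.
      const std::string group_name = "group-0-1-3-5-7";
      std::cout << group_name << " -> " << HashNameSketch(group_name) << std::endl;
      return 0;
    }
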
@@ -82,9 +82,9 @@ class DeviceManager { DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); } ~DeviceManager() = default; - Status Init(const RankList& devices, int32_t local_device, const RankList& stage_map, const std::string& backend); + Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend); - static DeviceManager& GetInstance(); + static DeviceManager &GetInstance(); RankList GetDeviceListByStageId(int32_t stage_id) const; RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const; @@ -92,8 +92,8 @@ class DeviceManager { std::vector CreateDeviceListByRankList(RankList ranks); std::string GenerateGroupNameByRanks(RankList dev_ranks); - Group CreateGroup(const std::string& group_name, const std::vector& devices); - Group CreateGroup(const RankList& dev_ranks); + Group CreateGroup(const std::string &group_name, const std::vector &devices); + Group CreateGroup(const RankList &dev_ranks); std::shared_ptr GetStageById(int32_t stage_id); size_t DeviceNum() const { return devices_.size(); } @@ -105,7 +105,7 @@ class DeviceManager { void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; } void Clear(); std::string world_group() const { return gm_.world_group(); } - std::string FindRankListNameByHashName(const std::string& hash_name); + std::string FindRankListNameByHashName(const std::string &hash_name); private: std::vector> devices_; diff --git a/mindspore/ccsrc/parallel/device_matrix.cc b/mindspore/ccsrc/parallel/device_matrix.cc index 3fdc3dd15a9..3c9467a2239 100644 --- a/mindspore/ccsrc/parallel/device_matrix.cc +++ b/mindspore/ccsrc/parallel/device_matrix.cc @@ -53,7 +53,7 @@ Status DeviceMatrix::CreateGroupList() { return Status::SUCCESS; } -Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) { +Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) { if (dim >= dev_shape_.size()) { MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!"; } @@ -78,7 +78,7 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) for (int32_t i = 0; i < step; i++) { local_group_list.push_back(group); - (void)std::for_each(group.begin(), group.end(), [](int32_t& a) { a++; }); + (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; }); } // higher than dim @@ -88,19 +88,19 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) // search rank int32_t target = rank_; for (int32_t i = 0; i < len; i++) { - for (RankList& temp : local_group_list) { + for (RankList &temp : local_group_list) { if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) { *devices = temp; return Status::SUCCESS; } - (void)std::for_each(temp.begin(), temp.end(), [step](int32_t& a) { a = a + step; }); + (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; }); } } MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!"; return Status::FAILED; } -Shape ConvertRankToCoordinate(int32_t rank, const Shape& dev_shape) { +Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) { Shape dev_coordinate; for (size_t i = 0; i < dev_shape.size(); ++i) { int32_t size = dev_shape[dev_shape.size() - i - 1]; @@ -115,8 +115,8 @@ Shape ConvertRankToCoordinate(int32_t rank, const Shape& dev_shape) { return dev_coordinate; } -Status DeviceMatrix::GetDevicesByTensorMap(const 
Shape& tensor_map, RankList* rank_list) { - for (auto& element : tensor_map) { +Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) { + for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. if (element == MAP_NONE) { continue; @@ -127,10 +127,10 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* ra } Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); - for (auto& tmp_rank : dev_list_) { + for (auto &tmp_rank : dev_list_) { Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_); bool matched = true; - for (auto& map : tensor_map) { + for (auto &map : tensor_map) { if (map == MAP_NONE) { continue; } @@ -148,7 +148,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* ra return SUCCESS; } -std::string ShapeToString(const Shape& shape) { +std::string ShapeToString(const Shape &shape) { std::string str = "["; for (size_t i = 0; i < shape.size(); ++i) { str += std::to_string(shape[i]); @@ -159,9 +159,9 @@ std::string ShapeToString(const Shape& shape) { return str + "]"; } -std::string ListToString(const std::vector& list) { +std::string ListToString(const std::vector &list) { std::string str = "["; - for (auto& element : list) { + for (auto &element : list) { str += std::to_string(element) + ", "; } return str + "]"; diff --git a/mindspore/ccsrc/parallel/device_matrix.h b/mindspore/ccsrc/parallel/device_matrix.h index a9120006041..236a7fad087 100644 --- a/mindspore/ccsrc/parallel/device_matrix.h +++ b/mindspore/ccsrc/parallel/device_matrix.h @@ -37,8 +37,8 @@ class DeviceMatrix { ~DeviceMatrix() = default; std::vector group_list() const { return group_list_; } Status CreateGroupList(); - Status GetDevicesByTensorMap(const Shape& tensor_map, RankList* rank_list); - Status GetDevicesAlongDim(const uint32_t& dim, RankList* devices); + Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list); + Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices); private: int32_t rank_ = -1; @@ -48,8 +48,8 @@ class DeviceMatrix { std::vector group_list_; }; -std::string ShapeToString(const Shape& shape); -std::string ListToString(const std::vector& list); +std::string ShapeToString(const Shape &shape); +std::string ListToString(const std::vector &list); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index bad947687d4..42ba42cf8a4 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -28,28 +28,28 @@ namespace mindspore { namespace parallel { #define REGISTER(className) \ - OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs& attrs) { \ + OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ return std::make_shared(name, in, out, attrs); \ } \ RegisterAction className##Register(#className, (CreatFn)objectCreator##className); -typedef OperatorInfoPtr (*CreatFn)(const std::string& name, const Shapes& shape_in, const Shapes shape_out, - const PrimitiveAttrs& attrs); +typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, + const PrimitiveAttrs &attrs); class DynCreator { public: ~DynCreator() = default; // creat static singleton dyn_creator instance - static DynCreator& Instance() { + static DynCreator 
&Instance() { static DynCreator fac = DynCreator(); return fac; } // register void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } // creator - OperatorInfoPtr Creat(const std::string& name, const Shapes& shape_in, const Shapes& shape_out, - const PrimitiveAttrs& attrs, size_t count) { + OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, + const PrimitiveAttrs &attrs, size_t count) { std::string op_name = name + std::to_string(count); auto iter = Function_map_.find(name); if (iter == Function_map_.end()) { @@ -66,7 +66,7 @@ class DynCreator { class RegisterAction { public: - RegisterAction(const std::string& name, CreatFn creatfn) : name_(name) { + RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { DynCreator::Instance().Regist(name, creatfn); } ~RegisterAction() = default; diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc index 43df9fe8026..f5f0fe85cb4 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc @@ -25,7 +25,7 @@ using mindspore::tensor::Tensor; namespace mindspore { namespace parallel { -std::string GetOpPythonPath(const OperatorName& op_name) { +std::string GetOpPythonPath(const OperatorName &op_name) { // almost all ops are defined in two main paths const std::string ops_module = OP_PATH; py::module mod = py::module::import(common::SafeCStr(ops_module)); @@ -35,7 +35,7 @@ std::string GetOpPythonPath(const OperatorName& op_name) { return ops_module; } -ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name, const std::string& instance_name) { +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { std::string op_path = GetOpPythonPath(op_name); py::module mod = py::module::import(common::SafeCStr(op_path)); if (!py::hasattr(mod, common::SafeCStr(op_name))) { @@ -44,7 +44,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name } std::vector arg_list; (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), - [](const Attr& attr) { return ValuePtrToPyData(attr.second); }); + [](const Attr &attr) { return ValuePtrToPyData(attr.second); }); py::object obj = parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list); ValuePtr op_instance = nullptr; @@ -56,7 +56,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name return op_instance; } -AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr& value_ptr) { +AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) { auto value_node = NewValueNode(value_ptr); MS_EXCEPTION_IF_NULL(value_node); return value_node->cast(); @@ -85,7 +85,7 @@ AnfNodePtr CreatInt32Imm(int32_t value) { return ValuePtrToAnfNodePtr(value_ptr); } -std::string GetInstanceNameByCNode(const CNodePtr& cnode) { +std::string GetInstanceNameByCNode(const CNodePtr &cnode) { PrimitivePtr prim = GetValueNode(cnode->input(0)); if (!prim) { MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr."; @@ -94,7 +94,7 @@ std::string GetInstanceNameByCNode(const CNodePtr& cnode) { return HashInstanceName(instance_name); } -std::string HashInstanceName(const std::string& name) { +std::string HashInstanceName(const std::string &name) { auto using_hash_name = 
common::GetEnv(USING_HASH_NAME); std::string instance_name; if ((using_hash_name.empty()) || (using_hash_name == "on")) { @@ -105,7 +105,7 @@ std::string HashInstanceName(const std::string& name) { return instance_name; } -Status GenerateGraph::Init(const CNodePtr& cnode) { +Status GenerateGraph::Init(const CNodePtr &cnode) { if (!cnode) { MS_LOG(ERROR) << "Init:cnode is nullptr"; return FAILED; @@ -133,7 +133,7 @@ Status GenerateGraph::Init(const CNodePtr& cnode) { return SUCCESS; } -AnfNodePtr GenerateGraph::PushBack(const std::vector& inputs) { +AnfNodePtr GenerateGraph::PushBack(const std::vector &inputs) { CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode MS_EXCEPTION_IF_NULL(cnode); cnode->set_scope(scope_); @@ -146,7 +146,7 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector& inputs) { return new_anf_node_ptr; } -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name, const OperatorAttrs& attrs) { +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) { name_idx_++; ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_)); if (pyop_instance == nullptr) { @@ -156,7 +156,7 @@ AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name, const OperatorA return value_node->cast(); } -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name) { +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) { name_idx_++; OperatorAttrs attrs; ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_)); diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/parallel/graph_util/generate_graph.h index c829e67b6a9..d5535c7dc20 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.h +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.h @@ -33,25 +33,25 @@ namespace mindspore { namespace parallel { #define USING_HASH_NAME "USING_HASH_NAME" // Get the operator's path where the operator has be defined -std::string GetOpPythonPath(const OperatorName& op_name); +std::string GetOpPythonPath(const OperatorName &op_name); // Init python operator Instance -ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name, const std::string& instance_name); +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name); AnfNodePtr CreatTypeInt(int32_t value); AnfNodePtr CreatInt32Imm(int32_t value); AnfNodePtr CreateInt32Tensor(int32_t value); -std::string HashInstanceName(const std::string& name); +std::string HashInstanceName(const std::string &name); class GenerateGraph { public: GenerateGraph() : name_idx_(0) {} - Status Init(const CNodePtr& cnode); + Status Init(const CNodePtr &cnode); ~GenerateGraph() = default; AnfNodePtr virtual_input_node() { return virtual_input_node_; } - AnfNodePtr NewOpInst(const OperatorName& op_name, const OperatorAttrs& attrs); - AnfNodePtr NewOpInst(const OperatorName& op_name); - AnfNodePtr PushBack(const std::vector& inputs); + AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs); + AnfNodePtr NewOpInst(const OperatorName &op_name); + AnfNodePtr PushBack(const std::vector &inputs); private: CNodePtr cnode_; diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc index 3006cb76801..cbffc10e701 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc 
+++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc @@ -29,7 +29,7 @@ namespace mindspore { namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr& graph) { +py::dict GetParameterLayout(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; std::vector graph_params = graph->parameters(); @@ -50,7 +50,7 @@ py::dict GetParameterLayout(const FuncGraphPtr& graph) { return dict; } -py::dict GetCNodeStrategy(const FuncGraphPtr& graph) { +py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; auto ret = graph->get_return(); @@ -75,7 +75,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr& graph) { return dict; } -py::dict GetAllreduceFusion(const FuncGraphPtr& graph) { +py::dict GetAllreduceFusion(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE); diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h index 78f597b2135..e21b81a557b 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h +++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h @@ -23,9 +23,9 @@ namespace mindspore { namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr& graph); -py::dict GetCNodeStrategy(const FuncGraphPtr& graph); -py::dict GetAllreduceFusion(const FuncGraphPtr& graph); +py::dict GetParameterLayout(const FuncGraphPtr &graph); +py::dict GetCNodeStrategy(const FuncGraphPtr &graph); +py::dict GetAllreduceFusion(const FuncGraphPtr &graph); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/parallel/graph_util/graph_info.cc index 46c9a37960b..175413c0fd7 100644 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/graph_info.cc @@ -24,12 +24,12 @@ namespace mindspore { namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr& graph, const std::string& name) { +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name) { AnfNodePtr ret = graph->get_return(); MS_EXCEPTION_IF_NULL(ret); std::vector all_nodes = DeepScopedGraphSearch(ret); std::vector prim_list; - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { if (!IsValueNode(node)) { continue; } @@ -44,7 +44,7 @@ std::vector FindPrimtive(const FuncGraphPtr& graph, const std::str return prim_list; } -void DumpGraph(const FuncGraphPtr& root, const std::string& name) { +void DumpGraph(const FuncGraphPtr &root, const std::string &name) { if (MsContext::GetInstance()->save_graphs_flag()) { draw::Draw(name + ".dot", root); DumpIR(name + ".ir", root); diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.h b/mindspore/ccsrc/parallel/graph_util/graph_info.h index 96deab29064..de800f09812 100644 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.h +++ b/mindspore/ccsrc/parallel/graph_util/graph_info.h @@ -24,8 +24,8 @@ namespace mindspore { namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr& graph, const std::string& name); -void DumpGraph(const FuncGraphPtr& root, const std::string& name); +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name); +void DumpGraph(const FuncGraphPtr &root, const std::string &name); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.cc b/mindspore/ccsrc/parallel/graph_util/node_info.cc index b2ce8ba432b..c085d712400 
100644 --- a/mindspore/ccsrc/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/node_info.cc @@ -23,13 +23,13 @@ namespace mindspore { namespace parallel { -std::string ParameterName(const AnfNodePtr& node_ptr) { +std::string ParameterName(const AnfNodePtr &node_ptr) { auto para_ptr = node_ptr->cast(); MS_EXCEPTION_IF_NULL(para_ptr); return para_ptr->name(); } -bool ParameterRequireGrad(const AnfNodePtr& node_ptr) { +bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { auto para_ptr = node_ptr->cast(); if (para_ptr == nullptr) { return false; diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.h b/mindspore/ccsrc/parallel/graph_util/node_info.h index f4f46d2149f..bda268e582f 100644 --- a/mindspore/ccsrc/parallel/graph_util/node_info.h +++ b/mindspore/ccsrc/parallel/graph_util/node_info.h @@ -22,9 +22,9 @@ namespace mindspore { namespace parallel { -std::string ParameterName(const AnfNodePtr& node_ptr); +std::string ParameterName(const AnfNodePtr &node_ptr); -bool ParameterRequireGrad(const AnfNodePtr& node_ptr); +bool ParameterRequireGrad(const AnfNodePtr &node_ptr); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/group_manager.h b/mindspore/ccsrc/parallel/group_manager.h index 430d2f64ed7..f763d483ccb 100644 --- a/mindspore/ccsrc/parallel/group_manager.h +++ b/mindspore/ccsrc/parallel/group_manager.h @@ -37,11 +37,11 @@ class Group { public: Group(); ~Group() = default; - Status Init(const std::string& name, const std::vector& devices); + Status Init(const std::string &name, const std::vector &devices); std::vector GetDevicesList() const; std::string name() const { return name_; } bool IsInThisGroup(int32_t device_rank); - Status GetIndex(size_t* index); + Status GetIndex(size_t *index); size_t GetDevNum() const { return devices_.size(); } private: @@ -54,14 +54,14 @@ class GroupManager { GroupManager(); ~GroupManager() = default; - Status CreateGroup(const std::string& name, const std::vector& devices, Group* group); - Status DestroyGroup(Group* group); + Status CreateGroup(const std::string &name, const std::vector &devices, Group *group); + Status DestroyGroup(Group *group); Status DestroyAllGroups(); - Status GetRankID(const std::string& name, unsigned int* rank_id); - Status GetRankSize(const std::string& name, unsigned int* rank_size); - Status FindGroup(const std::string& name, Group** group); + Status GetRankID(const std::string &name, unsigned int *rank_id); + Status GetRankSize(const std::string &name, unsigned int *rank_size); + Status FindGroup(const std::string &name, Group **group); std::string world_group() const { return world_group_; } - void set_world_group(const std::string& name) { world_group_ = name; } + void set_world_group(const std::string &name) { world_group_ = name; } void Clear(); private: diff --git a/mindspore/ccsrc/parallel/node_check.cc b/mindspore/ccsrc/parallel/node_check.cc index e43d03c29cb..7fecd307c78 100644 --- a/mindspore/ccsrc/parallel/node_check.cc +++ b/mindspore/ccsrc/parallel/node_check.cc @@ -80,7 +80,7 @@ const std::set BLACK_LIST = {TUPLE_GETITEM, REF_TO_EMBED, STOP_GRADIENT}; -bool IsInBlackList(const PrimitivePtr& prim) { +bool IsInBlackList(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); } diff --git a/mindspore/ccsrc/parallel/node_check.h b/mindspore/ccsrc/parallel/node_check.h index 6e5db37069d..8b628f31b12 100644 --- a/mindspore/ccsrc/parallel/node_check.h +++ 
b/mindspore/ccsrc/parallel/node_check.h @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -bool IsInBlackList(const PrimitivePtr& prim); +bool IsInBlackList(const PrimitivePtr &prim); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc index e659759de27..6bc33677a6a 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -Status Activation::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -41,7 +41,7 @@ Status Activation::SetCostUnderStrategy(const StrategyPtr& strategy) { return SUCCESS; } -Status Activation::CheckStrategy(const StrategyPtr& strategy) { +Status Activation::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -110,7 +110,7 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; @@ -120,7 +120,7 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status Softmax::CheckStrategy(const StrategyPtr& strategy) { +Status Softmax::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -133,7 +133,7 @@ Status Softmax::CheckStrategy(const StrategyPtr& strategy) { std::vector stra = strategy->GetInputDim(); Dimensions input_strategy = stra.at(0); - for (auto& element : axis_) { + for (auto &element : axis_) { int32_t axis_index = element; if (element < 0) { size_t input_dim = inputs_shape_.at(0).size(); @@ -176,7 +176,7 @@ Status Softmax::GetAttrs() { } std::vector value_vector = value_tuple->value(); (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_), - [](const ValuePtr& value) { return static_cast(GetValue(value)); }); + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); if (axis_.empty()) { MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; return FAILED; @@ -205,7 +205,7 @@ Status Softmax::GetAttrs() { return SUCCESS; } -Status Softmax::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -231,7 +231,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { is_auto_parallel_ = true; Shape input0_split; (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); - for (auto& element : axis_) { + for (auto &element : axis_) { int32_t axis_index = element; if (element < 0) { size_t input_dim = inputs_shape_.at(0).size(); @@ -247,7 +247,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for 
(auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; @@ -334,7 +334,7 @@ Status ActivationBase::InferTensorInfo() { return SUCCESS; } -Status ActivationBase::Init(const StrategyPtr& strategy) { +Status ActivationBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -344,7 +344,7 @@ Status ActivationBase::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status ActivationBase::InitForCostModel(const StrategyPtr& strategy) { +Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -547,7 +547,7 @@ Status ExpandDimsInfo::InferMirrorOps() { return SUCCESS; } -Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { +Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) { std::vector axis; auto axis_list = value_tuple->value(); if (inputs_shape_.empty()) { @@ -568,7 +568,7 @@ Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { } // convert negative axis to positive. - for (auto& dim : axis_list) { + for (auto &dim : axis_list) { if (!dim->isa()) { MS_LOG(ERROR) << name_ << ": The type of axis is not int"; return FAILED; @@ -595,7 +595,7 @@ Status SqueezeInfo::GetAttrs() { return SUCCESS; } -Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) { +Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { Attr attr = std::make_pair(AXIS, axis_); OperatorAttrs attrs = {attr}; OperatorParams params; @@ -689,7 +689,7 @@ Status SqueezeInfo::InferTensorInfo() { return SUCCESS; } -Status SqueezeInfo::Init(const StrategyPtr& strategy) { +Status SqueezeInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; } diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h index 887be5ea33b..a71c6b6df75 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h @@ -31,13 +31,13 @@ namespace mindspore { namespace parallel { class ActivationBase : public OperatorInfo { public: - ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ActivationBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; protected: Status InferMirrorOps() override; @@ -49,21 +49,21 @@ class ActivationBase : public OperatorInfo { class Activation : public ActivationBase { public: - Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationBase(name, inputs_shape, 
outputs_shape, attrs, std::make_shared(false)) {} ~Activation() override = default; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; }; class ActivationInfo : public Activation { public: - ActivationInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~ActivationInfo() override = default; @@ -73,8 +73,8 @@ class ActivationInfo : public Activation { class ActivationOther : public Activation { public: - ActivationOther(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~ActivationOther() override = default; @@ -84,31 +84,31 @@ class ActivationOther : public Activation { class GeluInfo : public ActivationOther { public: - GeluInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~GeluInfo() override = default; }; class TanhInfo : public ActivationOther { public: - TanhInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~TanhInfo() override = default; }; class Softmax : public ActivationBase { public: - explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Softmax() override = default; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; private: @@ -117,32 +117,32 @@ class Softmax : public ActivationBase { class SoftmaxInfo : public Softmax { public: - SoftmaxInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Softmax(name, inputs_shape, outputs_shape, attrs) {} ~SoftmaxInfo() override = default; }; class LogSoftmaxInfo : public Softmax { public: - LogSoftmaxInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const 
PrimitiveAttrs& attrs) + LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Softmax(name, inputs_shape, outputs_shape, attrs) {} ~LogSoftmaxInfo() override = default; }; class ReLUInfo : public ActivationOther { public: - ReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ReLUInfo() override = default; }; class CastInfo : public ActivationOther { public: - CastInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~CastInfo() override = default; @@ -152,23 +152,23 @@ class CastInfo : public ActivationOther { class SqrtInfo : public ActivationOther { public: - SqrtInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SqrtInfo() override = default; }; class NegInfo : public ActivationOther { public: - NegInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~NegInfo() override = default; }; class ExpandDimsInfo : public ActivationOther { public: - ExpandDimsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ExpandDimsInfo() override = default; @@ -187,18 +187,18 @@ class ExpandDimsInfo : public ActivationOther { class SqueezeInfo : public ActivationOther { public: - SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SqueezeInfo() override = default; protected: - Status InferAxis(const ValueTuplePtr& value_tuple); + Status InferAxis(const ValueTuplePtr &value_tuple); Status GetAttrs() override; - Status InferReplaceOps(const StrategyPtr& strategy); + Status InferReplaceOps(const StrategyPtr &strategy); Status InferTensorMap() override; Status InferTensorInfo() override; - Status Init(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; private: ValueTuplePtr axis_; @@ -206,8 +206,8 @@ class SqueezeInfo : public ActivationOther { class SquareInfo : public ActivationOther { public: - SquareInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SquareInfo(const std::string &name, const Shapes &inputs_shape, const 
Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SquareInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h index 78dfc238037..27caacc30cd 100644 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h @@ -31,92 +31,92 @@ namespace mindspore { namespace parallel { class ArithmeticBase : public OperatorInfo { public: - ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ArithmeticBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; void ReComputeBatchSplitFlagList() override; protected: Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); Shapes InferExpendShape(); }; class SubInfo : public ArithmeticBase { public: - SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SubInfo() override = default; }; class TensorAddInfo : public ArithmeticBase { public: - TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TensorAddInfo() override = default; }; class MulInfo : public ArithmeticBase { public: - MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MulInfo() override = default; }; class DivInfo : public ArithmeticBase { public: - DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + DivInfo(const std::string &name, const 
Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DivInfo() override = default; }; class RealDivInfo : public ArithmeticBase { public: - RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~RealDivInfo() override = default; }; class FloorDivInfo : public ArithmeticBase { public: - FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~FloorDivInfo() override = default; }; class PowInfo : public ArithmeticBase { public: - PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PowInfo() override = default; }; class GreaterInfo : public ArithmeticBase { public: - GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~GreaterInfo() override = default; }; class AssignSubInfo : public ArithmeticBase { public: - AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~AssignSubInfo() override = default; }; @@ -124,8 +124,8 @@ class AssignSubInfo : public ArithmeticBase { // All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. 
class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { public: - SigmoidCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SigmoidCrossEntropyWithLogitsInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc index 9d356cd573c..dac3b0a6759 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status BatchParallelInfo::CheckStrategy(const StrategyPtr& strategy) { +Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -161,7 +161,7 @@ Status BatchParallelInfo::InferTensorInfo() { Status BatchParallelInfo::GetAttrs() { return SUCCESS; } -Status BatchParallelInfo::Init(const StrategyPtr& strategy) { +Status BatchParallelInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -170,7 +170,7 @@ Status BatchParallelInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status BatchParallelInfo::InitForCostModel(const StrategyPtr& strategy) { +Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -184,7 +184,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h index 4cedb9b7b82..db6cb206d51 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h @@ -29,22 +29,22 @@ namespace mindspore { namespace parallel { class BatchParallelInfo : public OperatorInfo { public: - BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} - BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), dev_num_(1) {} ~BatchParallelInfo() override = default; - Status 
Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; @@ -60,8 +60,8 @@ class BatchParallelInfo : public OperatorInfo { class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { public: - SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, - const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, + const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; void ReComputeBatchSplitFlagList() override; diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h index e792858338b..37f555a258c 100644 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h +++ b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h @@ -32,26 +32,26 @@ namespace mindspore { namespace parallel { class BiasAddInfo : public OperatorInfo { public: - BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~BiasAddInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; void ReComputeBatchSplitFlagList() override; protected: Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h index 9ea496e0b02..8dd2976b049 100644 --- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h @@ -30,32 
+30,32 @@ namespace mindspore { namespace parallel { class EqualInfo : public ArithmeticBase { public: - EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~EqualInfo() override = default; }; class NotEqualInfo : public ArithmeticBase { public: - NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~NotEqualInfo() override = default; }; class MaximumInfo : public ArithmeticBase { public: - MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MaximumInfo() override = default; }; class MinimumInfo : public ArithmeticBase { public: - MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MinimumInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc index c755cc785d5..87b8d15cca0 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc @@ -32,7 +32,7 @@ namespace mindspore { namespace parallel { static int32_t SEED_NUM = 1; -Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null"; return FAILED; @@ -129,7 +129,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() { return SUCCESS; } -Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -159,7 +159,7 @@ Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; @@ -178,7 +178,7 @@ std::shared_ptr>> DropoutDoMaskInfo::GenerateBa return std::make_shared>>(strategy_v); } -Status DropoutDoMaskInfo::Init(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -188,7 +188,7 @@ Status DropoutDoMaskInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status 
DropoutDoMaskInfo::InitForCostModel(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -202,7 +202,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr& cnode) { +PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; @@ -237,7 +237,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr& cnode) { // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. -Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr& cnode) { +Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); MS_EXCEPTION_IF_NULL(prim); diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h index 3b154bd6db0..c0d112f52d4 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h @@ -31,20 +31,20 @@ namespace mindspore { namespace parallel { class DropoutDoMaskInfo : public OperatorInfo { public: - DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DropoutDoMaskInfo() override = default; - Status Init(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; - Operator GetDropoutGenMaskReplaceOp(const CNodePtr& cnode); + Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorMap() override; diff --git a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h index 84b8030f37a..2172c5cd89f 100644 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h @@ -29,37 +29,37 @@ namespace mindspore { namespace parallel { class ExpInfo : public ActivationOther { public: - ExpInfo(const 
std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ExpInfo() override = default; }; class LogInfo : public ActivationOther { public: - LogInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~LogInfo() override = default; }; class CosInfo : public ActivationOther { public: - CosInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~CosInfo() override = default; }; class ACosInfo : public ActivationOther { public: - ACosInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ACosInfo() override = default; }; class LogicalNotInfo : public ActivationOther { public: - LogicalNotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~LogicalNotInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc index c3159918495..c9e8835f352 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc @@ -70,7 +70,7 @@ Status GatherV2Info::GetAttrs() { return SUCCESS; } -Status GatherV2Info::CheckStrategy(const StrategyPtr& strategy) { +Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " << inputs_shape_.size(); @@ -256,7 +256,7 @@ Status GatherV2Info::InferTensorSubOps() { return SUCCESS; } -Status GatherV2Info::Init(const StrategyPtr& strategy) { +Status GatherV2Info::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -270,7 +270,7 @@ Status GatherV2Info::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status GatherV2Info::InitForCostModel(const StrategyPtr& strategy) { +Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -301,7 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success 
<< " strategy"; @@ -311,7 +311,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h index 773d46f4294..f7aeb6a0d9f 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h @@ -38,22 +38,22 @@ constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; // If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. class GatherV2Info : public OperatorInfo { public: - GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), axis_(-1), index_size_(0), axis_strategy_(1) {} ~GatherV2Info() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc index ac9acff41b0..29d519fda8a 100644 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc @@ -39,7 +39,7 @@ Status GetNextInfo::InferTensorMap() { return SUCCESS; } -Status GetNextInfo::InferTensorLayout(TensorLayouts* outputs_layout) { +Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { if (outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << " : The layout is null."; return FAILED; @@ -96,7 +96,7 @@ Status GetNextInfo::InferDevMatrixShape() { return SUCCESS; } -Status GetNextInfo::Init(const StrategyPtr& strategy) { +Status GetNextInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed"; return FAILED; @@ -109,7 +109,7 @@ Status GetNextInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status GetNextInfo::CheckStrategy(const StrategyPtr& strategy) { +Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { std::vector stras = strategy->GetInputDim(); for (Dimensions stra : stras) { if (stra.size() != 0) { @@ -135,7 +135,7 @@ Status GetNextInfo::GetAttrTypes() { auto iter_cast = iter->second->cast(); MS_EXCEPTION_IF_NULL(iter_cast); auto types = iter_cast->value(); - for (auto& type : types) { + for (auto &type : types) { MS_EXCEPTION_IF_NULL(type); 
types_.push_back(type->ToString()); } @@ -143,7 +143,7 @@ Status GetNextInfo::GetAttrTypes() { auto iter_cast = iter->second->cast(); MS_EXCEPTION_IF_NULL(iter_cast); auto types = iter_cast->value(); - for (auto& type : types) { + for (auto &type : types) { MS_EXCEPTION_IF_NULL(type); types_.push_back(type->ToString()); } @@ -189,7 +189,7 @@ Status GetNextInfo::GetAttrs() { return SUCCESS; } -Status GetNextInfo::InferReplaceOps(const StrategyPtr&) { +Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { Shapes out_shapes = outputs_shape_; for (size_t i = 0; i < out_shapes.size(); ++i) { if (dev_num_ <= 0) { @@ -214,7 +214,7 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr&) { return SUCCESS; } -Status GetNextInfo::InitForCostModel(const StrategyPtr& strategy) { +Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -227,7 +227,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc index 2955f765063..8716997d9f9 100644 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status L2NormalizeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -111,7 +111,7 @@ Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h index 22ed5a965b3..ca063d01d8a 100644 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h +++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h @@ -31,8 +31,8 @@ namespace mindspore { namespace parallel { class L2NormalizeInfo : public Activation { public: - L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~L2NormalizeInfo() override = default; Status GenerateStrategies(int32_t stage_id) override; @@ -40,7 +40,7 @@ class L2NormalizeInfo : public Activation { protected: Status GetAttrs() override; Status InferMirrorOps() override; - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; private: int32_t axis_ = 0; // Default value = 0 diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h 
b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h index c52645ade22..50117b81853 100644 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h +++ b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h @@ -38,20 +38,20 @@ constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; // arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. class LayerNormInfo : public OperatorInfo { public: - LayerNormInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), begin_norm_axis_(0) {} ~LayerNormInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; protected: Status GetAttrs() override; - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; @@ -61,7 +61,7 @@ class LayerNormInfo : public OperatorInfo { Status CreateTensorMap(size_t input_index); Status CreateTensorInfo(size_t input_index); Status CreateMirrorOp(size_t input_index); - Status GenerateGammaAndBetaStrategies(const std::vector& sp_vector); + Status GenerateGammaAndBetaStrategies(const std::vector &sp_vector); Status InitShapes(); private: diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/parallel/ops_info/loss_info.cc index 28ea19f1202..0ba325c0cd5 100644 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -152,7 +152,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -162,7 +162,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -205,7 +205,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { } size_t success = 0; 
- for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; @@ -216,7 +216,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { PrintStrategy(strategy); if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h index 44fe22ce906..2679c2d62b4 100644 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.h +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.h @@ -34,20 +34,20 @@ namespace parallel { // output_0 : [a], output_1: [a, b] class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { public: - SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SoftmaxCrossEntropyWithLogitsInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; void ReComputeBatchSplitFlagList() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc index 8d1264482b1..3f55efb66c7 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc @@ -31,8 +31,8 @@ namespace mindspore { namespace parallel { -void SetDevMatrixShape(const Dimensions& mat_a_strategy, const Dimensions& mat_b_strategy, bool transpose_b, - Shape* dev_matrix_shape) { +void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b, + Shape *dev_matrix_shape) { MS_EXCEPTION_IF_NULL(dev_matrix_shape); size_t mat_a_size = mat_a_strategy.size(); size_t mat_b_size = mat_b_strategy.size(); @@ -105,7 +105,7 @@ Status MatMulBase::GetAttrs() { return SUCCESS; } -Status CheckRelevantDimension(const Dimensions& long_strategy, const Dimensions& short_strategy) { +Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) { size_t long_size = long_strategy.size(); size_t short_size = short_strategy.size(); if (long_size < short_size) { @@ -126,7 +126,7 @@ Status CheckRelevantDimension(const Dimensions& long_strategy, const Dimensions& return SUCCESS; } -Status MatMul::CheckStrategy(const StrategyPtr& strategy) { +Status MatMul::CheckStrategy(const StrategyPtr &strategy) { if 
(CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -239,7 +239,7 @@ Status MatMulBase::InferForwardCommunication() { } // dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e]; -Dimensions GetOutputStrategy(const Shape& dev_matrix_shape, int32_t repeated_calculation_num) { +Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) { Dimensions output_strategy = dev_matrix_shape; if (repeated_calculation_num > 1) { // move the first dimension(repeated_calc_num_) @@ -301,7 +301,7 @@ Status MatMulBase::InferTensorMap() { return SUCCESS; } -Status MatMulBase::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { TensorLayout mat_a_layout, mat_b_layout, output_layout; if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || @@ -353,7 +353,7 @@ Status MatMulBase::InferTensorInfo() { return SUCCESS; } -Status MatMulBase::Init(const StrategyPtr& strategy) { +Status MatMulBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -363,7 +363,7 @@ Status MatMulBase::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status MatMulBase::InitForCostModel(const StrategyPtr& strategy) { +Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -377,7 +377,7 @@ Status MatMulBase::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape* const input) { +Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { if (input->size() < 2) { MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; return FAILED; @@ -463,7 +463,7 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, mindspore::parallel::StrategyPtr* const sp) { + size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); if (!FULLY_USE_DEVICES) { if (IntToSize(product) > dev_num) { @@ -519,7 +519,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, return SUCCESS; } -void MatMulBase::InitTensorInfoForCost(std::vector* relica_inputs_tensor_vector) { +void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { TensorLayout tly; if (transpose_a_) { Shape replica_input0_shape(inputs_tensor_info_[0].shape()); @@ -560,7 +560,7 @@ Status MatMulBase::CheckForTensorSliceValid() const { if (inputs_tensor_info_.empty()) { return FAILED; } - for (auto& one_input_tensor : inputs_tensor_info_) { + for (auto &one_input_tensor : inputs_tensor_info_) { auto slice_shape = one_input_tensor.slice_shape(); if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || 
(IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { @@ -570,7 +570,7 @@ Status MatMulBase::CheckForTensorSliceValid() const { return SUCCESS; } -Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (InitForCostModel(strategy) == FAILED) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h index 8a64fb7206f..86a74f78f26 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h @@ -32,21 +32,21 @@ namespace mindspore { namespace parallel { class MatMulBase : public OperatorInfo { public: - MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MatMulBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; // Generate all strategies and the corresponding cost for this MatMul operator Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, StrategyPtr* sp); + size_t input1_shape_size, StrategyPtr *sp); - Status SwapLastTwoElements(Shape* shape); + Status SwapLastTwoElements(Shape *shape); protected: Status InferMirrorOps() override; @@ -54,8 +54,8 @@ class MatMulBase : public OperatorInfo { Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); - void InitTensorInfoForCost(std::vector*); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + void InitTensorInfoForCost(std::vector *); Status CheckForTensorSliceValid() const; Status GetAttrs() override; @@ -67,26 +67,26 @@ class MatMulBase : public OperatorInfo { class MatMul : public MatMulBase { public: - MatMul(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : MatMulBase(name, inputs_shape, outputs_shape, attrs) {} ~MatMul() override = default; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; }; class MatMulInfo : public MatMul { public: - MatMulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : MatMul(name, inputs_shape, outputs_shape, attrs) {} ~MatMulInfo() override = 
default; }; class BatchMatMulInfo : public MatMul { public: - BatchMatMulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : MatMul(name, inputs_shape, outputs_shape, attrs) {} ~BatchMatMulInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc index e07609d3c4d..2c06a1ace94 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc @@ -54,7 +54,7 @@ Status OneHotInfo::GetAttrs() { return SUCCESS; } -Status OneHotInfo::CheckStrategy(const StrategyPtr& strategy) { +Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { if (inputs_shape_.size() != 3) { MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); return FAILED; @@ -185,7 +185,7 @@ Status OneHotInfo::ExtractInputInfo() { return SUCCESS; } -Status OneHotInfo::ComputeReplaceGraph(const CNodePtr& cnode) { +Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { if (dev_matrix_shape_.back() == 1) { replace_graph_ = nullptr; return SUCCESS; @@ -222,7 +222,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr& cnode) { return SUCCESS; } -ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr& cnode) { +ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { if (ComputeReplaceGraph(cnode) != SUCCESS) { MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; return nullptr; @@ -230,7 +230,7 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr& cnode) { return replace_graph_; } -Status OneHotInfo::Init(const StrategyPtr& strategy) { +Status OneHotInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -244,7 +244,7 @@ Status OneHotInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status OneHotInfo::InitForCostModel(const StrategyPtr& strategy) { +Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -276,7 +276,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; @@ -287,7 +287,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h index a4f00ea0936..3c8a64f9542 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.h @@ -31,20 +31,20 @@ namespace mindspore { namespace parallel { class OneHotInfo : public OperatorInfo { public: - OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + 
OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~OneHotInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } @@ -54,7 +54,7 @@ class OneHotInfo : public OperatorInfo { Status ExtractInputInfo(); private: - Status ComputeReplaceGraph(const CNodePtr& cnode); + Status ComputeReplaceGraph(const CNodePtr &cnode); int axis_ = -1; int32_t rank_ = 0; diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc index c6115a9fa69..8074f2a32ef 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc @@ -35,7 +35,7 @@ namespace mindspore { namespace parallel { -Status CheckStrategyValue(const StrategyPtr& strategy, const Shapes& inputs_shape, bool is_auto_parallel) { +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { if (strategy == nullptr) { MS_LOG(ERROR) << "The strategy is null."; return FAILED; @@ -190,7 +190,7 @@ Operator CreateVirtualDivOp(int32_t div_num) { } // use for forward all reduce -Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& group) { +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { OperatorName operator_name = ALL_REDUCE; ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM ValuePtr attr1_value = MakeValue(group); // group @@ -209,7 +209,7 @@ Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& grou } // use for get tensor slice -Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout) { +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { Shape tensor_map = tensor_layout.tensor_map().array(); Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); OperatorName operator_name = GET_TENSOR_SLICE; @@ -228,7 +228,7 @@ Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout) { return op; } -OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num) { +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { if ((dev_num == 0) || (dev_num == 1)) { MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; } @@ -260,7 +260,7 @@ OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num) { return op_for_weight; } -Status OperatorInfo::CreateGroupByTensorMap(const Shape& tensor_map, std::vector* group) { +Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { if (group == 
nullptr) { MS_LOG(ERROR) << "The group is null."; return FAILED; @@ -283,7 +283,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape& tensor_map, std::vector return SUCCESS; } -Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector* group) { +Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { if (group == nullptr) { MS_LOG(ERROR) << "The group is null."; return FAILED; @@ -306,7 +306,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector* group) { return SUCCESS; } -Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy) { +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { Shape slice_shape; if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; @@ -318,7 +318,7 @@ Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy) { return slice_shape; } -Status InferSliceShapeByStrategy(const Strategys& strategys, const Shapes& shapes, Shapes* slice_shapes) { +Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { if (slice_shapes == nullptr) { MS_LOG(ERROR) << "The slice_shapes is null."; return FAILED; @@ -357,8 +357,8 @@ Status InferSliceShapeByStrategy(const Strategys& strategys, const Shapes& shape return SUCCESS; } -Status OperatorInfo::InferSliceShape(const Strategys& inputs_strategy, const Strategys& outputs_strategy, - Shapes* inputs_slice_shape, Shapes* outputs_slice_shape) { +Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { MS_LOG(ERROR) << "The slice_shape is null."; return FAILED; @@ -379,7 +379,7 @@ Status OperatorInfo::InferSliceShape(const Strategys& inputs_strategy, const Str } // method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 -Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -437,7 +437,7 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strat } // method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape -Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -485,7 +485,7 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr& str return SUCCESS; } -Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -513,7 +513,7 @@ Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr& strategy) { return SUCCESS; } -Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ 
<< ": The strategy is null."; return FAILED; @@ -543,12 +543,12 @@ Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr& strategy) { std::vector> OperatorInfo::GetAliveSuccEdges() { std::vector> ret; - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { ret.push_back(edge); } } - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) { ret.push_back(edge); } @@ -558,7 +558,7 @@ std::vector> OperatorInfo::GetAliveSuccEdges() { std::vector> OperatorInfo::GetAlivePrevEdges() { std::vector> ret; - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator()->is_alive()) { ret.push_back(edge); } @@ -566,12 +566,12 @@ std::vector> OperatorInfo::GetAlivePrevEdges() { return ret; } -void OperatorInfo::ReplacePreEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null."; return; } - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator() == op) { edge = new_edge; return; @@ -580,12 +580,12 @@ void OperatorInfo::ReplacePreEdge(const std::shared_ptr& op, const MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; } -void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null."; return; } - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if (edge->next_operator() == op) { edge = new_edge; return; @@ -594,13 +594,13 @@ void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr& op, cons MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; } -void OperatorInfo::ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null."; return; } std::vector> new_pre_edges; - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator() != op) { new_pre_edges.push_back(edge); } @@ -609,13 +609,13 @@ void OperatorInfo::ReplacePreEdges(const std::shared_ptr& op, cons prev_edges_ = new_pre_edges; } -void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null"; return; } std::vector> new_succ_edges; - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if (edge->next_operator() != op) { new_succ_edges.push_back(edge); } @@ -625,7 +625,7 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr& op, con } std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes& shapes, const std::vector& split_flag_list) { + const Shapes &shapes, const std::vector &split_flag_list) { if (shapes.size() != split_flag_list.size()) { MS_LOG(ERROR) << "Split_flag_list do not 
have the same size as inputs shape, " << split_flag_list.size() << " : " << shapes.size(); @@ -665,14 +665,14 @@ void OperatorInfo::ComputeBatchSplitFlagList() { } // This is a common method for checking whether the generated stragegy has the correct number of devuces. -Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes& inputs_partitions, StrategyPtr* const sp) { +Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { if (sp == nullptr) { MS_LOG(ERROR) << "The strategy is null."; return FAILED; } int32_t product = 1; - for (auto& input_partition : inputs_partitions) { + for (auto &input_partition : inputs_partitions) { product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); } if (!FULLY_USE_DEVICES) { @@ -694,7 +694,7 @@ std::shared_ptr>> OperatorInfo::GenerateBatchSt return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); } -void PrintStrategy(const StrategyPtr& strategy) { +void PrintStrategy(const StrategyPtr &strategy) { if (strategy == nullptr) { return; } @@ -716,8 +716,8 @@ void PrintStrategy(const StrategyPtr& strategy) { } // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d]) -Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* const sp_vector) { +Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -740,7 +740,7 @@ Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& input return FAILED; } - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { sp->ExpandInputDimFromOneToTwo(); } @@ -749,8 +749,8 @@ Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& input // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -770,7 +770,7 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs } // second, get the correct strategy for input0 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { std::vector tmp_strategy; Dimensions input0_strategy = sp->GetInputDim()[0]; size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); @@ -798,8 +798,8 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { 
if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -819,7 +819,7 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& input } // second, get the correct strategy for input1 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { std::vector tmp_strategy; tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 @@ -848,8 +848,8 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& input // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -881,7 +881,7 @@ Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs } // step3: reset the strategy if the dimension is 1 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input1_strategy = sp->GetInputDim()[1]; for (size_t i = 0; i < inputs_shape[0].size(); ++i) { @@ -904,9 +904,9 @@ Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs // dimension is splittable. 'inputs_partitions' is the result of partitions. // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring // specific dimensions in inputs have the identical partition should have individual implementation. 
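The broadcast helpers above align the shorter input's strategy with the trailing dimensions of the longer input and never split a dimension of extent 1. A minimal, self-contained sketch of that alignment step (the function name and type aliases are illustrative, not the repository's API):

```cpp
#include <cstdint>
#include <vector>

using Dimensions = std::vector<int32_t>;
using Shape = std::vector<int32_t>;

// Reuse the full-rank input's partition for the trailing, aligned dimensions, and leave any
// broadcast dimension (extent 1) unsplit, mirroring the "get the correct strategy" step above.
Dimensions AlignBroadcastStrategy(const Dimensions &full_strategy, const Shape &broadcast_shape) {
  Dimensions result;
  size_t size_diff = full_strategy.size() - broadcast_shape.size();
  for (size_t i = 0; i < broadcast_shape.size(); ++i) {
    result.push_back(broadcast_shape[i] == 1 ? 1 : full_strategy[size_diff + i]);
  }
  return result;
}
```

For example, a full strategy of [2, 4, 2, 1] paired with a broadcast input of shape [16, 32] yields [2, 1], while a shape of [1, 32] yields [1, 1].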
-Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -932,7 +932,7 @@ Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& in MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size(); Shapes inputs_partitions; size_t global_index = 0; - for (auto& shape : inputs_shape) { + for (auto &shape : inputs_shape) { Shape tmp_partition; for (size_t j = 0; j < shape.size(); ++j) { tmp_partition.push_back(combined_partitions[global_index]); @@ -974,8 +974,8 @@ Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& in // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -1025,7 +1025,7 @@ Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_sh return SUCCESS; } -Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) { +Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { if (InitForCostModel(strategy) == FAILED) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed."; @@ -1063,8 +1063,8 @@ int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { return is_output_parameter_involve_; } is_parameter_involve_ = is_parameter_; - const auto& prev_edges = this->GetAlivePrevEdges(); - for (auto& p_edge : prev_edges) { + const auto &prev_edges = this->GetAlivePrevEdges(); + for (auto &p_edge : prev_edges) { auto input_index = p_edge->next_op_input_index(); auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); if (input_index >= is_parameter_involve_.size()) { @@ -1090,7 +1090,7 @@ int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { return is_output_parameter_involve_; } -Status OperatorInfo::set_is_parameter(const std::vector& is_parameter) { +Status OperatorInfo::set_is_parameter(const std::vector &is_parameter) { if (is_parameter.size() != inputs_shape_.size()) { MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size() << " do not have the same number of inputs_shape_: " << inputs_shape_.size(); @@ -1111,7 +1111,7 @@ Status OperatorInfo::CalculateMemoryCost() { operator_cost()->set_is_parameter_involve(is_parameter_involve_); operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); // Set the memory cost in the 'strategy_cost_' - for (auto& swc : strategy_cost_) { + for (auto &swc : strategy_cost_) { auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); swc->cost_list[0]->memory_with_reuse_ = mem_cost; } @@ -1119,7 +1119,7 @@ Status OperatorInfo::CalculateMemoryCost() { } Status 
OperatorInfo::CorrectMemoryCost(size_t input_index) { - for (auto& swc : strategy_cost_) { + for (auto &swc : strategy_cost_) { double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * static_cast(operator_cost()->inputs_type_lengths()[input_index]); swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; @@ -1132,13 +1132,13 @@ Status OperatorInfo::CorrectMemoryCost(size_t input_index) { return SUCCESS; } -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape& dev_matrix_shape, const Shape& tensor_map) { +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) { int32_t ret = -1; // The number of repetitions is equal to the number of all devices divided by the number of devices use for // tensor map. int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); - for (auto& element : tensor_map) { + for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. if (element == MAP_NONE) { continue; @@ -1211,8 +1211,8 @@ Status OperatorInfo::InferVirtualDivOps() { return SUCCESS; } -Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths) { +Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { if (input_lengths.size() != inputs_shape_.size()) { MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() << " do not have the same number of inputs shape: " << inputs_shape_.size(); @@ -1229,7 +1229,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& inpu return SUCCESS; } -Status OperatorInfo::set_outputs_type(const std::vector& outputs_type) { +Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { if (outputs_type.size() != outputs_shape_.size()) { MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() << " do not have the same number of outputs shape: " << outputs_shape_.size(); @@ -1239,7 +1239,7 @@ Status OperatorInfo::set_outputs_type(const std::vector& outputs_type) return SUCCESS; } -void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra, const CostPtr& cost) { +void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { CheckGlobalDeviceManager(); auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h index 19e0eeeda1e..347da7e573a 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h @@ -69,23 +69,23 @@ class OperatorInfo { virtual ~OperatorInfo() = default; - Status set_is_parameter(const std::vector& is_parameter); - Status SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths); + Status set_is_parameter(const std::vector &is_parameter); + Status SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths); // Set outputs dtype. // If only one output, outputs_type.size() is 1. // If output is tuple, outputs_type.size() is greater than 1. 
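The comment in ComputeRepeatDeviceNumByTensorMap above states that the repetition count is the total device count divided by the number of devices consumed by the tensor map, with MAP_NONE (-1) marking an unsplit dimension. A hedged sketch of that arithmetic (tail-first indexing of the device matrix is an assumption of this sketch, and the names are placeholders):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

constexpr int32_t kMapNone = -1;  // mirrors MAP_NONE above: the dimension is not split

// Total devices divided by the devices consumed by the mapped dimensions; entries are assumed
// to index the device matrix from its last dimension and to be within range.
int32_t RepeatedCalcNum(const std::vector<int32_t> &dev_matrix_shape, const std::vector<int32_t> &tensor_map) {
  int32_t device_num =
      std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int32_t>());
  int32_t used = 1;
  for (auto &element : tensor_map) {
    if (element == kMapNone) {
      continue;  // an unsplit dimension consumes no devices
    }
    used *= dev_matrix_shape[dev_matrix_shape.size() - 1 - static_cast<size_t>(element)];
  }
  return device_num / used;
}
```

With dev_matrix_shape = [2, 4, 2] (16 devices) and tensor_map = [2, 0], the mapped dimensions consume 2 * 2 = 4 devices, so the repetition count is 4.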
- Status set_outputs_type(const std::vector& outputs_type); - const std::vector& outputs_type() const { return outputs_type_; } - virtual Status Init(const StrategyPtr& strategy) = 0; - virtual Status InitForCostModel(const StrategyPtr& strategy) = 0; // only init the necessary parts + Status set_outputs_type(const std::vector &outputs_type); + const std::vector &outputs_type() const { return outputs_type_; } + virtual Status Init(const StrategyPtr &strategy) = 0; + virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts // Given the stage_id (which indicates the number of devices), // generate all strategies for this operator virtual Status GenerateStrategies(int32_t stage_id) = 0; - const OperatorCostPtr& operator_cost() const { return operator_cost_; } - void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; } - virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0; + const OperatorCostPtr &operator_cost() const { return operator_cost_; } + void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } + virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; virtual std::shared_ptr>> GenerateBatchStrategies(); virtual void ReComputeBatchSplitFlagList(); @@ -94,7 +94,7 @@ class OperatorInfo { double GetForwardMemoryCostFromCNode(); // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy // is checked - Status SetCostUnderStrategyBase(const StrategyPtr& strategy); + Status SetCostUnderStrategyBase(const StrategyPtr &strategy); std::vector> GetStrategyCost() { return strategy_cost_; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. 
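CorrectMemoryCost in the hunks above subtracts a parameter input's slice memory once per extra use, so a weight shared by several operators is only charged to memory once. An illustrative restatement of that bookkeeping (placeholder names, not the class's members):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Memory of one sliced parameter input: element count of the slice times the element byte length.
double SliceMemoryBytes(const std::vector<int32_t> &slice_shape, size_t type_length) {
  double elements = 1.0;
  for (auto dim : slice_shape) {
    elements *= static_cast<double>(dim);
  }
  return elements * static_cast<double>(type_length);
}

// One correction step: remove the double-counted parameter memory from the strategy's total.
double CorrectOnce(double memory_with_reuse, const std::vector<int32_t> &slice_shape, size_t type_length) {
  return memory_with_reuse - SliceMemoryBytes(slice_shape, type_length);
}
```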
@@ -104,61 +104,61 @@ class OperatorInfo { ForwardOp forward_op() const { return forward_op_; } ForwardOp replace_op() const { return replace_op_; } OutPutInfoVector replace_op_info() const { return replace_op_info_; } - virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; } + virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; } MirrorOps mirror_ops() const { return mirror_ops_; } Ops sub_ops() const { return sub_ops_; } VirtualDivOp virtual_div_op() const { return virtual_div_op_; } Shape dev_matrix_shape() const { return dev_matrix_shape_; } std::vector inputs_tensor_info() const { return inputs_tensor_info_; } std::vector outputs_tensor_info() const { return outputs_tensor_info_; } - const std::string& name() const { return name_; } - void set_name(const std::string& name) { name_ = name; } + const std::string &name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } RankList global_device_list() const { return global_device_list_; } - void AddSuccEdge(const std::shared_ptr& e) { succ_edges_.push_back(e); } - void AddPrevEdge(const std::shared_ptr& e) { prev_edges_.push_back(e); } + void AddSuccEdge(const std::shared_ptr &e) { succ_edges_.push_back(e); } + void AddPrevEdge(const std::shared_ptr &e) { prev_edges_.push_back(e); } std::vector> succ_edges() const { return succ_edges_; } std::vector> prev_edges() const { return prev_edges_; } std::vector> GetAliveSuccEdges(); std::vector> GetAlivePrevEdges(); - void ReplacePreEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); + void ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } - void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) { + void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) { selected_strategy_ = s_strategy; selected_cost_ = cost; } StrategyPtr selected_strategy() const { return selected_strategy_; } CostPtr selected_cost() const { return selected_cost_; } - Status InitSelectedStrategy(const StrategyPtr& s_strategy) { return Init(s_strategy); } - void set_input_value(const std::vector& input_value) { input_value_ = input_value; } - void set_outputs_dtype(const TypePtr& dtype) { outputs_dtype_ = dtype; } - void set_cnode(const CNodePtr& cnode) { cnode_ = cnode; } + Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } + void set_input_value(const std::vector &input_value) { input_value_ = input_value; } + void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } + void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } bool is_alive() const { return is_alive_; } void SetNotAlive() { is_alive_ = false; } StrategyPtr strategy() const { return strategy_; } - void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; } + void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; 
} void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } - const std::string& refkey_parameter_name() const { return refkey_parameter_name_; } + const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated // multiple times. This method is to correct this, and makes the cost is calulated only once. Status CorrectMemoryCost(size_t input_index); int is_output_parameter_involve() const { return is_output_parameter_involve_; } int used_devices() const { return used_devices_; } // needed by rec_parser - void set_type(const std::string& type) { type_ = type; } - const std::string& type() const { return type_; } - void set_cnode_name(const std::string& cnode_name) { cnode_name_ = cnode_name; } - const std::string& cnode_name() const { return cnode_name_; } - const std::unordered_map& attrs() const { return attrs_; } + void set_type(const std::string &type) { type_ = type; } + const std::string &type() const { return type_; } + void set_cnode_name(const std::string &cnode_name) { cnode_name_ = cnode_name; } + const std::string &cnode_name() const { return cnode_name_; } + const std::unordered_map &attrs() const { return attrs_; } protected: // needed by rec_parser std::string type_; std::string cnode_name_; - virtual Status CheckStrategy(const StrategyPtr& strategy) = 0; + virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; virtual Status InferTensorMap() = 0; virtual Status InferForwardCommunication() = 0; virtual Status InferMirrorOps() = 0; @@ -167,14 +167,14 @@ class OperatorInfo { virtual Status InferDevMatrixShape() = 0; void SetDeviceListByStrategy(); void SetRepeatedCalcDevMatrix(); - Status CreateGroupByTensorMap(const Shape& tensor_map, std::vector* group); - Status CreateGroupByDim(size_t axis, std::vector* group); + Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group); + Status CreateGroupByDim(size_t axis, std::vector *group); Status InferAttrs(); void ResetQueueMember(); - Status InitWithAutoRepeatCalc(const StrategyPtr& strategy); - Status InitWithManualRepeatCalc(const StrategyPtr& strategy); - Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strategy); - Status InitForCostModelWithManualRepeatCalc(const StrategyPtr& strategy); + Status InitWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitWithManualRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy); Status InferRepeatedCalcInfo(); Status InferVirtualDivOps(); @@ -182,9 +182,9 @@ class OperatorInfo { // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output // is used for grad and overload the function. If the output is a scalar, need to override the function too. 
virtual Status InferAsLossDivisor(); - Status InferSliceShape(const Strategys& inputs_strategy, const Strategys& outputs_strategy, - Shapes* inputs_slice_shape, Shapes* outputs_slice_shape); - void BreakingTiesForPerferringDataParallel(const StrategyPtr&, const CostPtr&); + Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); + void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); std::string name_; Shapes inputs_shape_; @@ -242,29 +242,29 @@ class OperatorInfo { std::vector outputs_type_; }; -Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy); -Status CheckStrategyValue(const StrategyPtr& strategy, const Shapes& inputs_shape, bool); +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool); Operator CreateVirtualDivOp(int32_t div_num); -Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& group); -Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout); -OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num); -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape& dev_matrix_shape, const Shape& tensor_map); +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes& shapes, const std::vector& split_flag_list); + const Shapes &shapes, const std::vector &split_flag_list); -void PrintStrategy(const StrategyPtr& strategy); +void PrintStrategy(const StrategyPtr &strategy); // generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* sp_vector); +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *sp_vector); // generate strategies for that have two inputs, and input0 or input1 maybe broadcast, // and the corresponding dimensions that are not broadcast are all relevant dimensions // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* sp_vector); +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *sp_vector); -Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc index a4d601dbe93..fed361616ba 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc @@ -34,7 +34,7 
@@ namespace parallel { * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. * the strategy of w should equal to the channel dimension of strategy of A */ -Status PReLUInfo::CheckStrategy(const StrategyPtr& strategy) { +Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -119,7 +119,7 @@ Dimensions PReLUInfo::GetOutputStrategy() { return output_strategy; } -Status PReLUInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if (inputs_layout == nullptr || outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -181,7 +181,7 @@ Status PReLUInfo::GetAttrs() { return SUCCESS; } -Status PReLUInfo::Init(const StrategyPtr& strategy) { +Status PReLUInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -190,7 +190,7 @@ Status PReLUInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status PReLUInfo::InitForCostModel(const StrategyPtr& strategy) { +Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -224,7 +224,7 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; @@ -234,7 +234,7 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h index 396407c1ee0..28e149fad76 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.h @@ -33,24 +33,24 @@ namespace parallel { */ class PReLUInfo : public OperatorInfo { public: - PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PReLUInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; 
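The PReLU comment above requires the weight's split to equal the channel dimension's split of the input. A toy check expressing that constraint (assuming an NCHW-style input whose channel dimension is index 1; this is not the actual PReLUInfo::CheckStrategy):

```cpp
#include <cstdint>
#include <vector>

// The weight is either a scalar or per-channel, so its single split factor must match the
// split factor of the input's channel dimension (assumed to be index 1 here).
bool PreluWeightStrategyMatches(const std::vector<int32_t> &input_strategy,
                                const std::vector<int32_t> &weight_strategy) {
  if (input_strategy.size() < 2 || weight_strategy.size() != 1) {
    return false;
  }
  return weight_strategy[0] == input_strategy[1];
}
```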
Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Dimensions GetOutputStrategy(); diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc index 4cb81ee7699..d6e1c277ef5 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status ReshapeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -137,7 +137,7 @@ Status ReshapeInfo::GetParameterInput() { return FAILED; } - for (auto& element : elements) { + for (auto &element : elements) { MS_EXCEPTION_IF_NULL(element); if (element->isa()) { int32_t axis = element->cast()->value(); @@ -216,7 +216,7 @@ Strategys ReshapeInfo::GetOutputsStrategy() { return outputs_strategy; } -Status ReshapeInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if (inputs_layout == nullptr || outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -302,7 +302,7 @@ void ReshapeInfo::InferTensorInfoByLayout() { */ Status ReshapeInfo::GetAttrs() { return GetParameterInput(); } -void ReshapeInfo::device_number(const StrategyPtr& strategy) { +void ReshapeInfo::device_number(const StrategyPtr &strategy) { int32_t stage = 0; if (strategy != nullptr) { stage = strategy->GetInputStage(); @@ -313,7 +313,7 @@ void ReshapeInfo::device_number(const StrategyPtr& strategy) { MS_ASSERT(dev_num_ > 0); } -Status ReshapeInfo::InferDefaultLayout(const Shape& shape, TensorLayout* const layout) { +Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { std::vector tensor_map_index; for (size_t i = 0; i < shape.size(); i++) { tensor_map_index.push_back(MAP_NONE); @@ -326,7 +326,7 @@ Status ReshapeInfo::InferDefaultLayout(const Shape& shape, TensorLayout* const l return Status::SUCCESS; } -Status ReshapeInfo::Init(const StrategyPtr& strategy) { +Status ReshapeInfo::Init(const StrategyPtr &strategy) { ResetQueueMember(); device_number(strategy); if (strategy) { @@ -375,7 +375,7 @@ Status ReshapeInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status ReshapeInfo::InitForCostModel(const StrategyPtr& strategy) { +Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -389,7 +389,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set 
cost under strategy failed."; @@ -423,7 +423,7 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h index 3864d2b93d6..99ee0141756 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h @@ -34,34 +34,34 @@ namespace parallel { */ class ReshapeInfo : public OperatorInfo { public: - ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), dev_num_(0), input_layout_set_flag_(false), output_layout_set_flag_(false) {} ~ReshapeInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - void SetInputLayout(const TensorLayout& input_layout) { + Status Init(const StrategyPtr &strategy) override; + void SetInputLayout(const TensorLayout &input_layout) { input_layout_ = input_layout; input_layout_set_flag_ = true; } - void SetOutputLayout(const TensorLayout& output_layout) { + void SetOutputLayout(const TensorLayout &output_layout) { output_layout_ = output_layout; output_layout_set_flag_ = true; } - Status InitForCostModel(const StrategyPtr& strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorMap() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Strategys GetOutputsStrategy(); @@ -69,8 +69,8 @@ class ReshapeInfo : public OperatorInfo { Status GetParameterInput(); Status ComputeReplaceOp(); void InferTensorInfoByLayout(); - void device_number(const StrategyPtr& strategy); - Status InferDefaultLayout(const Shape& shape, TensorLayout* const layout); + void device_number(const StrategyPtr &strategy); + Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); int32_t dev_num_; std::vector parameter_input_v_; diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h index 3682fe334fc..f7895d05112 100644 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h +++ b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h @@ -32,19 +32,19 @@ class TmpIdentityInfo : public OperatorInfo { // consider this parameter tensor as TmpIdentityInfo operator. TmpIdentityInfo operator tasks as input a tensor, // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor. 
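The TmpIdentityInfo comment above describes an operator that takes a tensor as input and outputs the same tensor, so that subsequent operators can share one output. A toy illustration of that sharing (placeholder types, not the framework's classes):

```cpp
#include <memory>
#include <vector>

struct Tensor {
  std::vector<float> data;
};

// Hypothetical identity op: outputs exactly the tensor it receives.
std::shared_ptr<Tensor> TmpIdentity(const std::shared_ptr<Tensor> &parameter) { return parameter; }

int main() {
  auto weight = std::make_shared<Tensor>();
  auto shared = TmpIdentity(weight);  // materialized once
  auto consumer_a = shared;           // both consumers reference the same output tensor
  auto consumer_b = shared;
  return (consumer_a.get() == consumer_b.get()) ? 0 : 1;
}
```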
public: - TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs, - const std::string& name = IDENTITY_INFO) + TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, + const std::string &name = IDENTITY_INFO) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TmpIdentityInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override { return SUCCESS; } Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc index 84333a1337f..49bbae0cb4e 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status TransposeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -43,7 +43,7 @@ Status TransposeInfo::CheckStrategy(const StrategyPtr& strategy) { Status TransposeInfo::InferDevMatrixShape() { std::vector stra = strategy_->GetInputDim(); input_strategy_ = stra.at(0); - for (auto& iter : input_strategy_) { + for (auto &iter : input_strategy_) { dev_matrix_shape_.push_back(iter); } return SUCCESS; @@ -77,7 +77,7 @@ Status TransposeInfo::ComputeAxis() { return FAILED; } axis_v_.clear(); - for (auto& element : elements) { + for (auto &element : elements) { MS_EXCEPTION_IF_NULL(element); if (element->isa()) { int32_t axis = element->cast()->value(); @@ -130,7 +130,7 @@ Strategys TransposeInfo::GetOutputsStrategy() { return outputs_strategy; } -Status TransposeInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -179,7 +179,7 @@ Status TransposeInfo::InferTensorInfo() { // compute axis_v_ during this method Status TransposeInfo::GetAttrs() { return ComputeAxis(); } -Status TransposeInfo::Init(const StrategyPtr& strategy) { +Status TransposeInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -188,7 +188,7 @@ Status TransposeInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status TransposeInfo::InitForCostModel(const StrategyPtr& strategy) { +Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << 
name_ << ": Init for cost model failed."; @@ -202,7 +202,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; @@ -234,7 +234,7 @@ Status TransposeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h index e4e2b90b7bc..50b76bde650 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.h @@ -33,23 +33,23 @@ namespace parallel { */ class TransposeInfo : public OperatorInfo { public: - TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TransposeInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Strategys GetOutputsStrategy(); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc index cd3b40315c1..4b695ba62d3 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -171,7 +171,7 @@ Status VirtualDatasetInfo::InferTensorInfo() { Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } -Status VirtualDatasetInfo::Init(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { if (InitWithManualRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; 
@@ -179,7 +179,7 @@ Status VirtualDatasetInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -199,7 +199,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { } } -Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; @@ -223,7 +223,7 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); StrategyPtr sp; std::vector strategy; - for (auto& shape : inputs_shape_) { + for (auto &shape : inputs_shape_) { Shape temp; temp.emplace_back(SizeToInt(total_dev_num)); (void)temp.insert(temp.end(), shape.size() - 1, 1); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h index 398bae3585d..312ac7a6a47 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h @@ -30,19 +30,19 @@ namespace mindspore { namespace parallel { class VirtualDatasetInfo : public OperatorInfo { public: - VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~VirtualDatasetInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; void ReComputeBatchSplitFlagList() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index bcd4dc3763b..d1390db899b 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -76,7 +76,7 @@ void SetCommunicationOpGroupLabel(std::vector new_node_input) { } } -std::vector CreateInput(const Operator& op, const AnfNodePtr& node, const std::string& instance_name) { +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { MS_EXCEPTION_IF_NULL(node); OperatorArgs arg_forward = op.second; ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); @@ -85,7 +85,7 @@ std::vector CreateInput(const Operator& op, const AnfNodePtr& node, std::vector new_node_input = {NewValueNode(pyop_instance), node}; if 
(!params.empty()) { - for (auto& param : params) { + for (auto &param : params) { AnfNodePtr val = NewValueNode(param.first.second); MS_EXCEPTION_IF_NULL(val); int32_t position = param.second; @@ -98,8 +98,8 @@ std::vector CreateInput(const Operator& op, const AnfNodePtr& node, return new_node_input; } -void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const AnfNodePtr& pre_node, - const FuncGraphPtr& func_graph, const std::string& instance_name) { +void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, + const FuncGraphPtr &func_graph, const std::string &instance_name) { // insert new node before the node FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -121,7 +121,7 @@ void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const An manager->SetEdge(node, SizeToInt(index), new_node); } -std::string CreateInstanceName(const CNodePtr& node, size_t index) { +std::string CreateInstanceName(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); if (!IsValueNode(node->input(0))) { MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive"; @@ -132,7 +132,7 @@ std::string CreateInstanceName(const CNodePtr& node, size_t index) { return instance_name; } -void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); // step1:get graph manager distribute_operator FuncGraphPtr func_graph = node->func_graph(); @@ -141,7 +141,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { MS_EXCEPTION_IF_NULL(manager); auto uses_set = manager->node_users()[node]; CNodePtr node_to_insert = node; - for (auto& uses_pair : uses_set) { + for (auto &uses_pair : uses_set) { auto uses_cnode = uses_pair.first->cast(); MS_EXCEPTION_IF_NULL(uses_cnode); if (!IsValueNode(uses_cnode->input(0))) { @@ -175,7 +175,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { } } -CNodePtr InsertMakeTuple(const AnfNodePtr& prev, uint32_t num, const FuncGraphPtr& func_graph) { +CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(prev); MS_EXCEPTION_IF_NULL(func_graph); std::vector make_tuple_inputs; @@ -195,8 +195,8 @@ CNodePtr InsertMakeTuple(const AnfNodePtr& prev, uint32_t num, const FuncGraphPt return make_tuple; } -void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_ptr, const CNodePtr& node, - const FuncGraphPtr& func_graph, int pos, const CNodePtr& pre_node) { +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(pre_node); MS_EXCEPTION_IF_NULL(func_graph); @@ -226,8 +226,8 @@ void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_p } } -void InsertGetTensorSliceOp(const Operator& op, const CNodePtr& node, const FuncGraphPtr& func_graph, int pos, - const std::string& instance_name) { +void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos, + const std::string &instance_name) { if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name; } @@ -244,8 +244,8 @@ void 
InsertGetTensorSliceOp(const Operator& op, const CNodePtr& node, const Func InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name); } -TensorLayout GetTensorInLayout(const CNodePtr& middle_node, const PrimitivePtr& middle_prim, - const OperatorInfoPtr& distribute_operator) { +TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, + const OperatorInfoPtr &distribute_operator) { TensorInfo tensorinfo_in; if (middle_prim->name() == TUPLE_GETITEM) { auto value_node = middle_node->input(2)->cast(); @@ -265,7 +265,7 @@ TensorLayout GetTensorInLayout(const CNodePtr& middle_node, const PrimitivePtr& return tensorinfo_in.tensor_layout(); } -OperatorInfoPtr GetDistributeOperator(const CNodePtr& node) { +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!IsParallelCareNode(node)) { return nullptr; @@ -277,9 +277,9 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr& node) { return distribute_operator; } -void Redistribution(const std::pair& node_pair, const OperatorInfoPtr& distribute_operator, - const CNodePtr& middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr& pre_node) { +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node) { FuncGraphPtr func_graph = middle_node->func_graph(); if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "Redistribution:get graph failed"; @@ -333,13 +333,13 @@ bool StrategyFound(std::unordered_map attrs) { return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); } -bool IsCommunicationOp(const PrimitivePtr& prim) { +bool IsCommunicationOp(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); } -bool FindCommunicationOp(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +bool FindCommunicationOp(const std::vector &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -364,7 +364,7 @@ bool FindCommunicationOp(const std::vector& all_nodes) { return false; } -bool IsParallelCareNode(const CNodePtr& cnode) { +bool IsParallelCareNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); ValueNodePtr prim_node = cnode->input(0)->cast(); if (prim_node == nullptr) { @@ -389,8 +389,8 @@ bool IsParallelCareNode(const CNodePtr& cnode) { return cnode->in_forward_flag(); } -void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_operator, const CNodePtr& insert_node, - const TensorRedistribution& tensor_redistribution, const CNodePtr& pre_node) { +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(node->func_graph()); FuncGraphManagerPtr manager = node->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -406,7 +406,7 @@ void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_ insert_node_new = insert_node; } MS_EXCEPTION_IF_NULL(insert_node_new); - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(use_cnode); if (!IsValueNode(use_cnode->input(0))) { @@ -429,7 +429,7 @@ void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& 
distribute_ } } -void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) { +void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(next_node); OperatorInfoPtr op_info = next_node->operator_info(); @@ -474,11 +474,11 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) { } } -void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) { +void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_cnode = node_pair.first->cast(); if (use_cnode == nullptr || !IsValueNode(use_cnode->input(0))) { continue; @@ -496,8 +496,8 @@ void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) } } -std::vector ReplaceOpInput(const Operator& replace_op, const std::string& instance_name, - const CNodePtr& node) { +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node) { OperatorArgs arg_replace_op = replace_op.second; ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name); if (pyop_instance == nullptr) { @@ -518,7 +518,7 @@ std::vector ReplaceOpInput(const Operator& replace_op, const std::st if (first_position == 1) { replace_input.pop_back(); } - for (auto& param : params) { + for (auto &param : params) { AnfNodePtr val = NewValueNode(param.first.second); if (val == nullptr) { MS_LOG(EXCEPTION) << "Failure:val is nullptr"; @@ -531,7 +531,7 @@ std::vector ReplaceOpInput(const Operator& replace_op, const std::st return replace_input; } -void ReplaceOneOp(const Operator& replace_op, const CNodePtr& node) { +void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = func_graph->manager(); @@ -551,7 +551,7 @@ void ReplaceOneOp(const Operator& replace_op, const CNodePtr& node) { (void)manager->Replace(node, replace_node); } -void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node) { +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { // step1:get graph manager distribute_operator OperatorInfoPtr distribute_operator = node->operator_info(); if (distribute_operator == nullptr) { @@ -599,15 +599,15 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node) { MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); } -bool IsSomePrimitive(const CNodePtr& cnode, const std::string& name) { +bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { ValueNodePtr anf_node = cnode->input(0)->cast(); MS_EXCEPTION_IF_NULL(anf_node); PrimitivePtr prim = anf_node->value()->cast(); return (prim->name() == name); } -void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>>& replace_graph, - const CNodePtr& node) { +void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>> &replace_graph, + const CNodePtr &node) { MS_EXCEPTION_IF_NULL(replace_graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(replace_graph->second); @@ -627,7 +627,7 @@ void StepReplaceGraph(const std::shared_ptr, A if (replace_graph->first.size() != 2) { MS_LOG(EXCEPTION) << "Failure:replace_graph->first.size() must be 2 for OneHot Primitive!"; } - for (auto& replace_input : 
replace_graph->first) { + for (auto &replace_input : replace_graph->first) { MS_EXCEPTION_IF_NULL(replace_input); manager->SetEdge(replace_input, 1, pre_node); CNodePtr replace_input_cnode = replace_input->cast(); @@ -645,7 +645,7 @@ void StepReplaceGraph(const std::shared_ptr, A replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node } -int32_t GetTupleGetItemIndex(const CNodePtr& cnode) { +int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() != 3) { MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; @@ -666,7 +666,7 @@ int32_t GetTupleGetItemIndex(const CNodePtr& cnode) { // Judge whether the node is a loss, and if there are multiple outputs, // get which output is a grad according to the tuple getitem. // Currently, it is not supported that the sens is a tuple. -LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { +LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { MS_EXCEPTION_IF_NULL(loss_node); FuncGraphPtr sub_graph = loss_node->func_graph(); MS_EXCEPTION_IF_NULL(sub_graph); @@ -718,7 +718,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { MS_LOG(EXCEPTION) << "Invalid loss"; } -void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node) { +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -742,7 +742,7 @@ void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node } } -std::pair FindParameter(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { if (!node->isa() && !node->isa() && !node->isa()) { return std::make_pair(nullptr, false); } else if (node->isa()) { @@ -790,7 +790,7 @@ std::pair FindParameter(const AnfNodePtr& node, const FuncGrap return std::make_pair(nullptr, false); } -std::pair FindCNode(const AnfNodePtr& anode, const std::string& name, const FuncGraphPtr& func_graph) { +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(anode); MS_EXCEPTION_IF_NULL(anode->func_graph()); FuncGraphManagerPtr manager = anode->func_graph()->manager(); @@ -798,7 +798,7 @@ std::pair FindCNode(const AnfNodePtr& anode, const std::string& AnfNodeIndexSet node_set = manager->node_users()[anode]; bool result = false; CNodePtr cnode_return = nullptr; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -820,7 +820,7 @@ std::pair FindCNode(const AnfNodePtr& anode, const std::string& return std::make_pair(result, cnode_return); } -bool IsCastBeforMirror(const CNodePtr& node, size_t index) { +bool IsCastBeforMirror(const CNodePtr &node, size_t index) { // only if cast_before_mirror is true, pre node is cast and type is not float32 return true if (!ParallelContext::GetInstance()->cast_before_mirror()) { return false; @@ -850,7 +850,7 @@ bool IsCastBeforMirror(const CNodePtr& node, size_t index) { return (type_id != kNumberTypeFloat32); } -void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = 
node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -887,7 +887,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { } std::string instance_name = MIRROR_OP; if (IsCastBeforMirror(node, index)) { - for (auto& op : backward_op) { + for (auto &op : backward_op) { // insert new node before the node CNodePtr cnode = node->input(index)->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -895,7 +895,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); } } else { - for (auto& op : backward_op) { + for (auto &op : backward_op) { AnfNodePtr pre_node = node->input(index); InsertNode(op, node, index, pre_node, func_graph, instance_name); } @@ -903,7 +903,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { } } -void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNodePtr& node, bool is_loss_node) { +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(node); MirrorOps mirror_ops = distribute_operator->mirror_ops(); @@ -920,7 +920,7 @@ void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNo } } -std::string GetDisOpName(const std::string& prim_name) { +std::string GetDisOpName(const std::string &prim_name) { std::string op_name = prim_name; if (!prim_name.empty() && (prim_name[0] == '_')) { op_name = prim_name.substr(1); @@ -928,8 +928,8 @@ std::string GetDisOpName(const std::string& prim_name) { return op_name + "Info"; } -OperatorInfoPtr OperatorInstanceByName(const std::string& name, const PrimitiveAttrs& attrs, - const std::vector& shape_list) { +OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { if (shape_list.size() != 2) { MS_LOG(ERROR) << "The size of shape list is not 2"; return nullptr; @@ -951,8 +951,8 @@ OperatorInfoPtr OperatorInstanceByName(const std::string& name, const PrimitiveA return operator_; } -OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, - const std::vector& shape_list) { +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list); if (operator_ == nullptr) { @@ -963,7 +963,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& return operator_; } -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, std::vector shape_list) { OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); for (size_t i = 0; i < shape_list[0].size(); ++i) { @@ -992,7 +992,7 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { std::vector value_vector = value_tuple->value(); (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), - [](const ValuePtr& value) { return static_cast(GetValue(value)); }); + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); strategy.push_back(dim); } else { MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! 
Need ValueSequeue"; @@ -1007,7 +1007,7 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { return strategyPtr; } -Shapes GetNodeShape(const AnfNodePtr& node) { +Shapes GetNodeShape(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); Shapes shapes; BaseShapePtr base_shape_ptr = node->Shape(); @@ -1039,7 +1039,7 @@ Shapes GetNodeShape(const AnfNodePtr& node) { auto tuple_shape_ptr = dyn_cast(base_shape_ptr); if (tuple_shape_ptr != nullptr) { auto tuple_shape = tuple_shape_ptr->shape(); - for (auto& shape : tuple_shape) { + for (auto &shape : tuple_shape) { auto each_shape = dyn_cast(shape); MS_EXCEPTION_IF_NULL(each_shape); shapes.push_back(each_shape->shape()); @@ -1052,7 +1052,7 @@ Shapes GetNodeShape(const AnfNodePtr& node) { return shapes; } -std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(func_graph); std::vector parameters; @@ -1075,7 +1075,7 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr root_g = roots.back(); MS_EXCEPTION_IF_NULL(root_g); - for (auto& param_node : root_g->parameters()) { + for (auto &param_node : root_g->parameters()) { auto param = param_node->cast(); if (param && (name == param->name())) { parameters.push_back(param_node); @@ -1088,7 +1088,7 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const return parameters; } -Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(func_graph); @@ -1107,7 +1107,7 @@ Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph return input_shapes; } -std::vector ExtractShape(const CNodePtr& node) { +std::vector ExtractShape(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); Shapes shape_inputs, shape_outputs; std::vector shape_all; @@ -1145,14 +1145,14 @@ std::vector ExtractShape(const CNodePtr& node) { return shape_all; } -std::pair FindParallelCareNode(const AnfNodePtr& node) { +std::pair FindParallelCareNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -1174,7 +1174,7 @@ std::pair FindParallelCareNode(const AnfNodePtr& node) { return std::make_pair(nullptr, 0); } -std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNodePtr& parameter) { +std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr &parameter) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(parameter); FuncGraphManagerPtr manager = graph->manager(); @@ -1184,7 +1184,7 @@ std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNode return prim_anf_node_pair; } else { AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; - for (auto& param_pair : param_sub_set) { + for (auto &param_pair : param_sub_set) { CNodePtr graph_cnode = param_pair.first->cast(); if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa()) { continue; @@ -1208,7 +1208,7 @@ std::pair FindSubGraph(const FuncGraphPtr& graph, const 
AnfNode return std::make_pair(nullptr, 0); } -void SetParallelShape(const AnfNodePtr& parameter, const std::pair& res) { +void SetParallelShape(const AnfNodePtr &parameter, const std::pair &res) { MS_EXCEPTION_IF_NULL(parameter); AbstractBasePtr abstract = parameter->abstract(); MS_EXCEPTION_IF_NULL(abstract); @@ -1237,10 +1237,10 @@ void SetParallelShape(const AnfNodePtr& parameter, const std::pairset_tensor_layout(std::make_shared(tensor_layout)); } -void CoverSliceShape(const FuncGraphPtr& root) { +void CoverSliceShape(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); auto parameters = root->parameters(); - for (auto& parameter : parameters) { + for (auto &parameter : parameters) { MS_EXCEPTION_IF_NULL(parameter->Shape()); auto iter = g_RefMap.find(parameter); if (iter != g_RefMap.end()) { @@ -1258,7 +1258,7 @@ void CoverSliceShape(const FuncGraphPtr& root) { g_RefMap.clear(); } -bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) { +bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_node) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(parameter_node); FuncGraphManagerPtr manager = root->manager(); @@ -1281,9 +1281,9 @@ bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_nod return true; } -void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { +void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); - for (auto& cloned_parameter_node : root->parameters()) { + for (auto &cloned_parameter_node : root->parameters()) { MS_EXCEPTION_IF_NULL(cloned_parameter_node); auto cloned_parameter = cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(cloned_parameter); @@ -1300,7 +1300,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { bool found_be_cloned_parameter = false; ParameterPtr cloned_from_parameter = nullptr; AnfNodePtr cloned_from_node = nullptr; - for (auto& be_cloned_parameter_node : root->parameters()) { + for (auto &be_cloned_parameter_node : root->parameters()) { MS_EXCEPTION_IF_NULL(be_cloned_parameter_node); auto be_cloned_parameter = be_cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(be_cloned_parameter); @@ -1315,7 +1315,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { // get the be cloned index py::list be_cloned_index = parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED_INDEX); - for (auto& index : be_cloned_index) { + for (auto &index : be_cloned_index) { if (cloned_index == py::cast(index)) { found_be_cloned_parameter = true; cloned_from_parameter = be_cloned_parameter; @@ -1341,7 +1341,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { } } -void SetVirtualDatasetStrategy(const CNodePtr& node) { +void SetVirtualDatasetStrategy(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); PrimitivePtr prim = GetValueNode(node->input(0)); MS_EXCEPTION_IF_NULL(prim); @@ -1370,8 +1370,8 @@ void SetVirtualDatasetStrategy(const CNodePtr& node) { } } -void ExtractInformation(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void ExtractInformation(const std::vector &all_nodes) { + for (auto &node : all_nodes) { auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; @@ -1390,7 +1390,7 @@ void ExtractInformation(const std::vector& all_nodes) { if (operator_ == nullptr) { MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); std::vector 
input_value; for (size_t index = 1; index < inputs.size(); ++index) { if (inputs[index]->isa()) { @@ -1440,7 +1440,7 @@ void ExtractInformation(const std::vector& all_nodes) { } } -TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair) { +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { CNodePtr cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); @@ -1456,13 +1456,13 @@ TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair } // if reshape's output connect to several primitive, return the first layout found -std::shared_ptr FindNextLayout(const CNodePtr& cnode) { +std::shared_ptr FindNextLayout(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode->func_graph()); FuncGraphManagerPtr manager = cnode->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[cnode]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -1492,7 +1492,7 @@ std::shared_ptr FindNextLayout(const CNodePtr& cnode) { return nullptr; } -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, size_t output_index) { +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { MS_EXCEPTION_IF_NULL(cnode); OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); MS_EXCEPTION_IF_NULL(distribute_operator); @@ -1505,7 +1505,7 @@ std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, si return std::make_shared(tensorlayout_out); } -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& node, size_t output_index) { +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { if (!node->isa()) { return nullptr; } @@ -1523,7 +1523,7 @@ std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& n return nullptr; } -std::shared_ptr FindPrevLayout(const AnfNodePtr& node) { +std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (node->isa()) { MS_LOG(EXCEPTION) << "Failure: parameter before reshape is not supported temporary"; } @@ -1567,8 +1567,8 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr& node) { return nullptr; } -void ReshapeInit(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void ReshapeInit(const std::vector &all_nodes) { + for (auto &node : all_nodes) { auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; @@ -1608,7 +1608,7 @@ void ReshapeInit(const std::vector& all_nodes) { } // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) -bool IsGradSensNode(const AnfNodePtr& node) { +bool IsGradSensNode(const AnfNodePtr &node) { if (!node->isa()) { return false; } @@ -1660,7 +1660,7 @@ bool IsGradSensNode(const AnfNodePtr& node) { return (expect_j_prim->name() == J); } -TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { +TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { MS_EXCEPTION_IF_NULL(loss_cnode); AnfNodePtr node = loss_cnode->cast(); MS_EXCEPTION_IF_NULL(node); @@ -1700,7 +1700,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { return ret; } -void SplitSens(const AnfNodePtr& grad_sens_node, const TensorLayout& loss_grad_layout) { +void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout 
&loss_grad_layout) { MS_EXCEPTION_IF_NULL(grad_sens_node); auto cnode = grad_sens_node->cast(); @@ -1752,7 +1752,7 @@ void SplitSens(const AnfNodePtr& grad_sens_node, const TensorLayout& loss_grad_l InsertGetTensorSliceOp(op, cnode, func_graph, 1, SPLIT_SENS); } -void InsertForwardOps(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); OperatorVector forward_op = distribute_operator->forward_op(); @@ -1762,7 +1762,7 @@ void InsertForwardOps(const OperatorInfoPtr& distribute_operator, const CNodePtr } } -void StepReplace(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); // StepReplaceOp @@ -1783,7 +1783,7 @@ void StepReplace(const OperatorInfoPtr& distribute_operator, const CNodePtr& cno } } -void HandleDropoutNode(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); @@ -1801,12 +1801,12 @@ void HandleDropoutNode(const OperatorInfoPtr& distribute_operator, const CNodePt ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); } -void HandleSpecialNode(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { HandleDropoutNode(distribute_operator, cnode); } -void ParallelCommunication(const FuncGraphPtr& root, const std::vector& all_nodes, - const FuncGraphManagerPtr& manager) { +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(manager); TensorRedistribution tensor_redistribution; @@ -1817,7 +1817,7 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector(node); MS_EXCEPTION_IF_NULL(symbolic_key); auto all_upstream_node = root->manager()->node_users()[node]; - for (auto& upstream_node : all_upstream_node) { + for (auto &upstream_node : all_upstream_node) { FuncGraphPtr fg = upstream_node.first->func_graph(); if (symbolic_key->node()->isa()) { - for (auto& param : root->parameters()) { + for (auto &param : root->parameters()) { if (*param == *symbolic_key->node()) { AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param}); MS_EXCEPTION_IF_NULL(reverted_node); @@ -1889,9 +1889,9 @@ void RevertSymbolicKeyInstance(const FuncGraphPtr& root, const AnfNodePtr& node) } } // namespace -void HandleSymbolicKeyInstance(const FuncGraphPtr& root, const std::vector& all_nodes) { +void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector &all_nodes) { MS_EXCEPTION_IF_NULL(root); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { // revert back SymbolicKeyInstance to embed() primitive if (IsValueNode(node)) { RevertSymbolicKeyInstance(root, node); @@ -1900,13 +1900,13 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr& root, const std::vectorget_return(); auto all_nodes = DeepScopedGraphSearch(ret); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if ((cnode == nullptr) || 
!IsValueNode(cnode->input(0))) { @@ -1931,7 +1931,7 @@ void CheckpointStrategy(const FuncGraphPtr& func_graph) { } } -void RestoreStrategy(const FuncGraphPtr& func_graph) { +void RestoreStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(INFO) << "Extract strategy from checkpoint begin"; StrategyMap straMap; @@ -1943,7 +1943,7 @@ void RestoreStrategy(const FuncGraphPtr& func_graph) { } auto ret = func_graph->get_return(); auto all_nodes = DeepScopedGraphSearch(ret); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { @@ -1968,8 +1968,8 @@ void RestoreStrategy(const FuncGraphPtr& func_graph) { } } -void SetForwardFlag(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void SetForwardFlag(const std::vector &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -1986,8 +1986,8 @@ void SetForwardFlag(const std::vector& all_nodes) { } } -void SetForwardFlag(const AnfNodeSet& all_nodes) { - for (auto& node : all_nodes) { +void SetForwardFlag(const AnfNodeSet &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -2003,7 +2003,7 @@ void SetForwardFlag(const AnfNodeSet& all_nodes) { } } -CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { +CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); CNodePtr return_node = func_graph->get_return(); MS_EXCEPTION_IF_NULL(return_node); @@ -2059,8 +2059,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { return pre_cnode; } -FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet& root_all_nodes) { - for (auto& node : root_all_nodes) { +FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { + for (auto &node : root_all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -2088,11 +2088,11 @@ FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet& root_all_nodes) { return nullptr; } -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root) { +CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); AnfNodePtr root_return_node = root->get_return(); MS_EXCEPTION_IF_NULL(root_return_node); - const auto& all_nodes = root->nodes(); + const auto &all_nodes = root->nodes(); FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); if (func_graph == nullptr) { return FindLossCNode(root); @@ -2101,12 +2101,12 @@ CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root) { } } -FuncGraphPtr ForwardGraph(const FuncGraphPtr& root) { +FuncGraphPtr ForwardGraph(const FuncGraphPtr &root) { FuncGraphPtr forward_graph = root; MS_EXCEPTION_IF_NULL(root); AnfNodePtr root_return_node = root->get_return(); MS_EXCEPTION_IF_NULL(root_return_node); - const auto& all_nodes = root->nodes(); + const auto &all_nodes = root->nodes(); FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); if (func_graph != nullptr) { forward_graph = func_graph; @@ -2114,11 +2114,11 @@ FuncGraphPtr ForwardGraph(const FuncGraphPtr& root) { return forward_graph; } -void MarkForwardCNode(const FuncGraphPtr& root) { +void MarkForwardCNode(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); AnfNodePtr root_return_node = root->get_return(); MS_EXCEPTION_IF_NULL(root_return_node); - auto& all_nodes = root->nodes(); + auto &all_nodes = root->nodes(); FuncGraphPtr func_graph = 
FindForwardGraphByRootNodes(all_nodes); if (func_graph == nullptr) { @@ -2178,7 +2178,7 @@ Status ParallelInit() { return SUCCESS; } -bool StepParallel(const FuncGraphPtr& root, const opt::OptimizerPtr& optimizer) { +bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); @@ -2258,12 +2258,12 @@ bool StepParallel(const FuncGraphPtr& root, const opt::OptimizerPtr& optimizer) } // Needed by rec_parser -std::vector ExtractInputsTensorName(const CNodePtr& node) { +std::vector ExtractInputsTensorName(const CNodePtr &node) { std::vector name_inputs; std::vector all_inputs = node->inputs(); std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - for (auto& input : node_inputs) { + for (auto &input : node_inputs) { std::string name; if (IsValueNode(input) || input->isa() || input->isa()) { name = input->ToString(); diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h index fd47a59bf55..184d11d1737 100644 --- a/mindspore/ccsrc/parallel/step_parallel.h +++ b/mindspore/ccsrc/parallel/step_parallel.h @@ -41,114 +41,114 @@ struct LossNodeInfo { int dout_index = 0; // now don't support the sens is a tuple }; -std::vector CreateInput(const Operator& op, const AnfNodePtr& node, const std::string& instance_name); -std::string CreateInstanceName(const CNodePtr& node, size_t index); -void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node); +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); +std::string CreateInstanceName(const CNodePtr &node, size_t index); +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); -void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_ptr, const CNodePtr& node, - const FuncGraphPtr& func_graph, int pos, const CNodePtr& pre_node); +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node); -TensorLayout GetTensorInLayout(const CNodePtr& pre_node, const PrimitivePtr& pre_prim, - const OperatorInfoPtr& distribute_operator_pre); +TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim, + const OperatorInfoPtr &distribute_operator_pre); -OperatorInfoPtr GetDistributeOperator(const CNodePtr& node); +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node); -void Redistribution(const std::pair& node_pair, const OperatorInfoPtr& distribute_operator, - const CNodePtr& middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr& pre_node); +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node); bool StrategyFound(std::unordered_map attrs); -bool IsParallelCareNode(const CNodePtr& cnode); +bool IsParallelCareNode(const CNodePtr &cnode); -void MarkForwardCNode(const FuncGraphPtr& root); +void MarkForwardCNode(const FuncGraphPtr &root); -bool FindCommunicationOp(const std::vector& all_nodes); +bool FindCommunicationOp(const std::vector &all_nodes); -void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_operator, const CNodePtr& insert_node, - const TensorRedistribution& tensor_redistribution, const CNodePtr& pre_node); +void 
StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node); -std::vector ReplaceOpInput(const Operator& replace_op, const std::string& instance_name, - const CNodePtr& node); +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node); -void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node); +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); -void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node); +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node); -std::pair FindParameter(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph); -std::pair FindCNode(const AnfNodePtr& anode, const std::string& name, const FuncGraphPtr& func_graph); +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); -void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node); +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); -void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNodePtr& node, bool is_loss_node); +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node); // Generate and init parallel operator -OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, - const std::vector& shape_list); +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list); // Generate without initing parallel operator -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, std::vector shape_list); // Extract strategy from attr StrategyPtr ExtractStrategy(std::unordered_map attrs); -Shapes GetNodeShape(const AnfNodePtr& node); +Shapes GetNodeShape(const AnfNodePtr &node); -std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); // Extract shape from anfnode -std::vector ExtractShape(const CNodePtr& node); +std::vector ExtractShape(const CNodePtr &node); -std::pair FindParallelCareNode(const AnfNodePtr& node); +std::pair FindParallelCareNode(const AnfNodePtr &node); // Find finally sub graph -std::pair FindSubGraph(const FuncGraphPtr& func_graph, const AnfNodePtr& parameter); +std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter); // Set distribute shape for parameters abstract -void SetParallelShape(const AnfNodePtr& parameter, const std::pair& res); +void SetParallelShape(const AnfNodePtr &parameter, const std::pair &res); // change parameters'shape in resource -void CoverSliceShape(const FuncGraphPtr& root); +void CoverSliceShape(const FuncGraphPtr &root); -void SetVirtualDatasetStrategy(const CNodePtr& node); +void SetVirtualDatasetStrategy(const CNodePtr &node); // Creat parallel operator for primitive node(has strategy) -void ExtractInformation(const std::vector& all_nodes); +void ExtractInformation(const std::vector &all_nodes); -TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair); +TensorLayout GetInputLayoutFromCNode(const 
std::pair &node_pair); -std::shared_ptr FindNextLayout(const CNodePtr& node); +std::shared_ptr FindNextLayout(const CNodePtr &node); -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, size_t output_index); +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index); -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& node, size_t output_index); +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index); -std::shared_ptr FindPrevLayout(const AnfNodePtr& node); +std::shared_ptr FindPrevLayout(const AnfNodePtr &node); -void ReshapeInit(const std::vector& all_nodes); +void ReshapeInit(const std::vector &all_nodes); // Add node for whole graph -void ParallelCommunication(const FuncGraphPtr& root, const std::vector& all_nodes, - const FuncGraphManagerPtr& manager); +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager); -void RestoreStrategy(const FuncGraphPtr& func_graph); +void RestoreStrategy(const FuncGraphPtr &func_graph); -void CheckpointStrategy(const FuncGraphPtr& func_graph); +void CheckpointStrategy(const FuncGraphPtr &func_graph); // main step of Parallel -bool StepParallel(const FuncGraphPtr& func_graph, const opt::OptimizerPtr& optimizer); +bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); -int32_t GetTupleGetItemIndex(const CNodePtr& cnode); +int32_t GetTupleGetItemIndex(const CNodePtr &cnode); -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root); +CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root); Status ParallelInit(); -std::vector ExtractInputsTensorName(const CNodePtr& node); +std::vector ExtractInputsTensorName(const CNodePtr &node); -FuncGraphPtr ForwardGraph(const FuncGraphPtr& root); +FuncGraphPtr ForwardGraph(const FuncGraphPtr &root); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/strategy.h b/mindspore/ccsrc/parallel/strategy.h index 93d4d4dff1a..fce99305a5e 100644 --- a/mindspore/ccsrc/parallel/strategy.h +++ b/mindspore/ccsrc/parallel/strategy.h @@ -46,7 +46,7 @@ class Strategy { inputs_.push_back(inputs_[0]); } } - void ResetInputs(const std::vector& input) { inputs_ = input; } + void ResetInputs(const std::vector &input) { inputs_ = input; } private: const int32_t stage_; @@ -55,7 +55,7 @@ class Strategy { std::vector inputs_; }; -inline StrategyPtr NewStrategy(const int32_t stage, const std::vector& inputs) { +inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { return std::make_shared(stage, inputs); } } // namespace parallel diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index 9e3573eee25..dd518dc76ce 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -StrategyCheckpoint& StrategyCheckpoint::GetInstance() { +StrategyCheckpoint &StrategyCheckpoint::GetInstance() { static StrategyCheckpoint instance = StrategyCheckpoint(); return instance; } @@ -47,7 +47,7 @@ Status StrategyCheckpoint::RemoveCheckPoint() const { return FAILED; } -Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { +Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { if (strategy_map == nullptr) { MS_LOG(EXCEPTION) 
<< "Failure:strategy_map is nullptr"; } @@ -82,18 +82,18 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { straspb::ParallelStrategyMap parallel_strategy_map; parallel_strategy_map.set_train_time(IntToUint(++current_train_time_)); - for (auto& node_stra : strategy_map) { - straspb::ParallelStrategyItem* parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); + for (auto &node_stra : strategy_map) { + straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); MS_EXCEPTION_IF_NULL(parallel_strategy_item); parallel_strategy_item->set_node_name(node_stra.first); - straspb::ParallelStrategys* parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); + straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); MS_EXCEPTION_IF_NULL(parallel_strategys); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); - for (auto& dims : node_stra.second->GetInputDim()) { - straspb::ParallelStrategy* parallel_strategy = parallel_strategys->add_parallel_strategy(); + for (auto &dims : node_stra.second->GetInputDim()) { + straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); MS_EXCEPTION_IF_NULL(parallel_strategy); for (auto dim : dims) { parallel_strategy->add_dim(IntToUint(dim)); diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h index b5d3626f532..c871ea6eef1 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -32,11 +32,11 @@ class StrategyCheckpoint { StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) { train_times_ = 1; checkpoint_on_ = false; - const char* train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); + const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); if (train_times_str != nullptr && std::stoi(train_times_str) > 0) { train_times_ = std::stoi(train_times_str); } - const char* checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); + const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); if (checkpoint_on_str != nullptr) { checkpoint_on_ = (std::string(checkpoint_on_str) == "on"); } @@ -44,10 +44,10 @@ class StrategyCheckpoint { ~StrategyCheckpoint() = default; bool CheckPointExit() const; Status RemoveCheckPoint() const; - Status Load(StrategyMap* strategy_map); - Status Save(const StrategyMap& strategy_map); + Status Load(StrategyMap *strategy_map); + Status Save(const StrategyMap &strategy_map); - static StrategyCheckpoint& GetInstance(); + static StrategyCheckpoint &GetInstance(); int32_t GetTrainTimes() const { return train_times_; } int32_t GetCurrentTrainTime() const { return current_train_time_; } bool CheckPointOn() const { return checkpoint_on_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc index b42ba302427..235ab00302d 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Arrangement::Init(const std::vector& array) { 
+Status Arrangement::Init(const std::vector &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -45,7 +45,7 @@ bool Arrangement::IsValidArrangement() { void Arrangement::ComputeSize() { size_ = 1; - for (auto& value : array_) { + for (auto &value : array_) { size_ *= value; } } @@ -84,7 +84,7 @@ std::vector Arrangement::GetFrontElementByValue(int32_t value) const { } std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft( - const std::vector& expand_list) const { + const std::vector &expand_list) const { if (expand_list.size() != GetDimSize()) { return nullptr; } @@ -108,7 +108,7 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft * array_ = [8, 4], * arrangement_list = [[4, 2], [2, 2]] */ -std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement& expand_shape) const { +std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { int32_t size = 1; uint32_t ind = 0; std::vector arrangement_list; @@ -140,7 +140,7 @@ std::shared_ptr> Arrangement::GetExpandShapeList(const } std::shared_ptr, Arrangement>> Arrangement::GetExpandShapeListPair( - const Arrangement& expand_shape) const { + const Arrangement &expand_shape) const { std::shared_ptr> expand_shape_list_ptr = GetExpandShapeList(expand_shape); if (expand_shape_list_ptr == nullptr) { return nullptr; @@ -148,7 +148,7 @@ std::shared_ptr, Arrangement>> Arrangement::G std::vector expand_num_list_shape; (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), std::back_inserter(expand_num_list_shape), - [](const Arrangement& arr) { return SizeToInt(arr.GetDimSize()); }); + [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); Arrangement expand_num_list; Status status = expand_num_list.Init(expand_num_list_shape); if (status != Status::SUCCESS) { @@ -169,7 +169,7 @@ std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() co } std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLeft( - const std::vector& expand_list) const { + const std::vector &expand_list) const { if (expand_list.size() != GetDimSize()) { return nullptr; } @@ -191,7 +191,7 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLef return std::make_shared(arrangement_new); } -std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement& in2) const { +std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2) const { std::vector in1_accum; Status status = ShapeToAccumulateProduct(array_, &in1_accum); if (status != Status::SUCCESS) { diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h index 2dc13038c10..ca71b05c915 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h +++ b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h @@ -32,18 +32,18 @@ class Arrangement : public Array { public: Arrangement() : size_(1) {} ~Arrangement() override = default; - Status Init(const std::vector& array) override; + Status Init(const std::vector &array) override; int32_t size() const { return size_; } std::vector GetFrontElementByValue(int32_t value) const; - std::shared_ptr> GetExpandShapeList(const Arrangement& expand_shape) const; + std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; std::vector ComputeReverseAccumulateSumInReverseOrder() const; std::shared_ptr GetExpandedShapeByExpandListReserveLeft( - const std::vector& expand_list) const; + const std::vector &expand_list) 
const; std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( - const std::vector& expand_list) const; + const std::vector &expand_list) const; std::shared_ptr, Arrangement>> GetExpandShapeListPair( - const Arrangement& expand_shape) const; - std::shared_ptr GetUnifiedShape(const Arrangement& in2) const; + const Arrangement &expand_shape) const; + std::shared_ptr GetUnifiedShape(const Arrangement &in2) const; std::vector GetSqueezeIdx() const; Arrangement GetSqueezeArrangement() const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.cc b/mindspore/ccsrc/parallel/tensor_layout/array.cc index ba3858ae009..ef358e7cded 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/array.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/array.cc @@ -24,14 +24,14 @@ namespace parallel { std::string Array::ToString() const { std::ostringstream buffer; buffer << "[ "; - for (auto& element : array_) { + for (auto &element : array_) { buffer << std::to_string(element) + " "; } buffer << "]"; return buffer.str(); } -Status Array::Init(const std::vector& array) { +Status Array::Init(const std::vector &array) { array_ = array; return IsvalidArray() ? Status::SUCCESS : Status::FAILED; } @@ -54,7 +54,7 @@ int32_t Array::GetDimByReverseIdx(uint32_t idx) const { return array_[GetDimSize() - 1 - mod_idx]; } -bool Array::operator==(const Array& shape) const { +bool Array::operator==(const Array &shape) const { if (GetDimSize() != shape.GetDimSize()) { return false; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.h b/mindspore/ccsrc/parallel/tensor_layout/array.h index f7d9c3c673b..5aa3bdb1389 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/array.h +++ b/mindspore/ccsrc/parallel/tensor_layout/array.h @@ -31,13 +31,13 @@ class Array { Array() = default; virtual ~Array() = default; std::string ToString() const; - virtual Status Init(const std::vector& array); + virtual Status Init(const std::vector &array); bool IsvalidArray() const; std::vector array() const { return array_; } size_t GetDimSize() const { return array_.size(); } int32_t GetDimByIdx(uint32_t idx) const; int32_t GetDimByReverseIdx(uint32_t idx) const; - bool operator==(const Array& a1) const; + bool operator==(const Array &a1) const; protected: std::vector array_; diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc index 829c056fc2c..b5ca5ed60a1 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -Status ConstructOperator::Init(const RankList& dev_list, const Shape& dev_matrix_shape) { +Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) { dev_size_ = dev_matrix_shape.size(); dev_matrix_shape_ = dev_matrix_shape; dev_list_ = dev_list; @@ -46,7 +46,7 @@ Status ConstructOperator::ReshapeOP(Shape shape) { return Status::SUCCESS; } -Operator CreateStridedSliceOp(int32_t value, const Shape& begin, const Shape& end, const Shape& strides) { +Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &end, const Shape &strides) { ValuePtr attr_value = MakeValue(value); Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value); Attr attr_end_mask = std::make_pair(END_MASK, attr_value); @@ -230,7 +230,7 @@ Status ConstructOperator::AlltoAllOP(Args args) { return Status::SUCCESS; } -Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector* 
group) { +Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector *group) { MS_EXCEPTION_IF_NULL(group); CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h index cf6cff456a6..1a69638fb65 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h +++ b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h @@ -34,7 +34,7 @@ class ConstructOperator { const int32_t DEFAULT = 0; ConstructOperator() : dev_size_(0) {} ~ConstructOperator() = default; - Status Init(const RankList& dev_list, const Shape& dev_matrix_shape); + Status Init(const RankList &dev_list, const Shape &dev_matrix_shape); Status ReshapeOP(Shape shape); Status StridedSliceOP(Args args); Status AllGatherOP(int32_t dev_dim); @@ -42,7 +42,7 @@ class ConstructOperator { Status ConcatOP(int32_t concat_dim); Status AlltoAllOP(Args args); Operator GetOperator() const { return op_; } - void UpdateTensorShape(const Shape& tensor_shape) { tensor_shape_ = tensor_shape; } + void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } private: Operator op_; @@ -50,7 +50,7 @@ class ConstructOperator { Shape tensor_shape_; RankList dev_list_; Shape dev_matrix_shape_; - Status CreateGroupByDim(size_t axis, std::vector* group); + Status CreateGroupByDim(size_t axis, std::vector *group); }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc index 190a5846baa..84c0580ba87 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc @@ -29,7 +29,7 @@ std::string LayoutTransfer::ToString() const { LayoutTransfer::~LayoutTransfer() = default; -Status LayoutTransfer::Init(const TensorLayout& from_in, const TensorLayout& to_in) { +Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) { from_in_ = from_in; to_in_ = to_in; MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString(); diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h index b05128f5b82..c4da4b728f9 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h @@ -28,7 +28,7 @@ class LayoutTransfer { LayoutTransfer() = default; virtual ~LayoutTransfer() = 0; std::string ToString() const; - Status Init(const TensorLayout& from_in, const TensorLayout& to_in); + Status Init(const TensorLayout &from_in, const TensorLayout &to_in); TensorLayout from_in() const { return from_in_; } TensorLayout to_in() const { return to_in_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.cc b/mindspore/ccsrc/parallel/tensor_layout/map.cc index 320dbe6ebd8..669920fc446 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/map.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/map.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Map::Init(const std::vector& array) { +Status Map::Init(const std::vector &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -46,7 +46,7 @@ bool Map::IsValidMap() { std::vector sorted_array = array_; std::sort(sorted_array.begin(), sorted_array.end()); int32_t value = MAP_NONE; - for (auto& element : sorted_array) { + for (auto &element 
: sorted_array) { if (element == MAP_NONE) { continue; } @@ -78,7 +78,7 @@ int32_t Map::GetIndexByValue(int32_t value) const { /* * expand.size() should be equal to array_.size() */ -std::shared_ptr Map::ExpandMapByNone(const Arrangement& expand_num_list) const { +std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) const { if (expand_num_list.GetDimSize() != GetDimSize()) { return nullptr; } @@ -105,7 +105,7 @@ std::shared_ptr Map::ExpandMapByNone(const Arrangement& expand_num_list) co /* * expand.size() should be equal to array_.size() */ -std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement& expand_num_list) const { +std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { return nullptr; } @@ -126,7 +126,7 @@ std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement& expand_nu return map_new; } -std::shared_ptr> Map::ReMapVector(const std::vector& input_vector) const { +std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { if (GetMaxItem() >= static_cast(input_vector.size())) { return nullptr; } @@ -143,7 +143,7 @@ std::shared_ptr> Map::ReMapVector(const std::vector idx_list) const { - for (auto& value : idx_list) { + for (auto &value : idx_list) { if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { return false; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.h b/mindspore/ccsrc/parallel/tensor_layout/map.h index 3f839ef1989..8c8bba27750 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/map.h +++ b/mindspore/ccsrc/parallel/tensor_layout/map.h @@ -34,12 +34,12 @@ class Map : public Array { public: Map() = default; ~Map() override = default; - Status Init(const std::vector& array) override; + Status Init(const std::vector &array) override; int32_t GetMaxItem() const; int32_t GetIndexByValue(int32_t value) const; - std::shared_ptr ExpandMapByNone(const Arrangement& expand_num_list) const; - std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement& expand_num_list) const; - std::shared_ptr> ReMapVector(const std::vector& input_vector) const; + std::shared_ptr ExpandMapByNone(const Arrangement &expand_num_list) const; + std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; + std::shared_ptr> ReMapVector(const std::vector &input_vector) const; bool CheckNoneByIdxList(std::vector idx_list) const; Map SqueezeMapByIdxList(std::vector idx_list) const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc index ac768c19f95..946620ec4c2 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, +Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, bool is_cost_model) { in_tensor_map_ = tensor_layout.tensor_map(); dev_mat_ = tensor_layout.device_arrangement(); @@ -105,7 +105,7 @@ Status RedistributionOperatorInfer::InferSplitByAxis() { } if (in_dim == NONE && !std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { + [out_dim](const RedistributionOperatorMap::value_type &a) { return 
a.second == out_dim; })) { Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { MS_LOG(ERROR) << "Insert SplitByAxis Error!"; @@ -130,7 +130,7 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { } if (in_dim == NONE && std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); if (is_cost_model_) { diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h index 8fd953572a6..a96097a1d3c 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h @@ -40,7 +40,7 @@ class RedistributionOperatorInfer { public: const int NONE = -1; explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} - Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list, + Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, bool is_cost_model = false); ~RedistributionOperatorInfer() = default; OperatorList operator_list() const { return operator_list_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc index 39a6bef92da..f6c90e9d466 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc @@ -104,7 +104,7 @@ std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo( } std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement& expand_shape) const { + const Arrangement &expand_shape) const { std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); if (extend_tensor_shape_from_ptr == nullptr) { return nullptr; diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h index 8aae71631df..ed62cb59dad 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h @@ -33,7 +33,7 @@ class ReshapeLayoutTransfer : public LayoutTransfer { std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement& expand_shape) const; + const Arrangement &expand_shape) const; std::shared_ptr ExchangeFromAndTo() const; private: diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc index a26627fb3ce..e8f208708cf 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc @@ -26,7 +26,7 @@ namespace parallel { * shape = [2, 8, 32] * shape_accum = [2, 2 * 8, 2 * 8 * 32] */ -Status ShapeToAccumulateProduct(const std::vector& shape, std::vector* shape_accum) { +Status ShapeToAccumulateProduct(const 
std::vector &shape, std::vector *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -47,7 +47,7 @@ Status ShapeToAccumulateProduct(const std::vector& shape, std::vector& shape, std::vector* shape_accum) { +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector& shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector& shape_accum, std::vector* shape) { +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -92,7 +92,7 @@ Status AccumulateProductToShape(const std::vector& shape_accum, std::ve * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] * shape = [2, 8, 32] */ -Status AccumulateProductReverseToShape(const std::vector& shape_accum_reverse, std::vector* shape) { +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -122,8 +122,8 @@ Status AccumulateProductReverseToShape(const std::vector& shape_accum_r * in2 = [8, 16] * *out = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector& in1_accum, const std::vector& in2_accum, - std::vector* out_accum) { +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum) { MS_EXCEPTION_IF_NULL(out_accum); out_accum->clear(); auto in1_iter = in1_accum.begin(); @@ -159,7 +159,7 @@ Status UnifyAccumulateProduct(const std::vector& in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector& in1, const std::vector& in2, std::vector* out) { +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out) { MS_EXCEPTION_IF_NULL(out); std::vector in1_accum; Status status = ShapeToAccumulateProduct(in1, &in1_accum); @@ -194,9 +194,9 @@ Status UnifyShape(const std::vector& in1, const std::vector& i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, - const std::vector& expand_accum_reverse, - std::vector* out_accum_reverse) { +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse) { MS_EXCEPTION_IF_NULL(out_accum_reverse); out_accum_reverse->clear(); auto in_riter = in_accum_reverse.rbegin(); @@ -236,7 +236,7 @@ Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector& in, const std::vector& expand, std::vector* out) { +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out) { MS_EXCEPTION_IF_NULL(out); std::vector in_accum_reverse; Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h index e83156500c9..2ec21f3881e 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h +++ b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h @@ -39,7 +39,7 @@ namespace parallel { * shape_accum = [2, 2 * 8, 2 * 8 * 32] * */ -Status ShapeToAccumulateProduct(const std::vector& shape, 
std::vector* shape_accum); +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); /* * compute the accumulating product of all the values in shape from right to left, @@ -53,7 +53,7 @@ Status ShapeToAccumulateProduct(const std::vector& shape, std::vector& shape, std::vector* shape_accum); +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); /* * compute the original shape from the accumulating product shape_accum, @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector& shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector& shape_accum, std::vector* shape); +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); /* * compute the original shape from the accumulating product shape_accum, @@ -83,7 +83,7 @@ Status AccumulateProductToShape(const std::vector& shape_accum, std::ve * shape = [2, 8, 32] * */ -Status AccumulateProductReverseToShape(const std::vector& shape_accum_reverse, std::vector* shape); +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); /* * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, @@ -101,8 +101,8 @@ Status AccumulateProductReverseToShape(const std::vector& shape_accum_r * in2_accum = [8, 16] * out_accum = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector& in1_accum, const std::vector& in2_accum, - std::vector* out_accum); +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum); /* * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] @@ -117,7 +117,7 @@ Status UnifyAccumulateProduct(const std::vector& in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector& in1, const std::vector& in2, std::vector* out); +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); /* * given two accumulate product in reverse order of in and expand, @@ -141,9 +141,9 @@ Status UnifyShape(const std::vector& in1, const std::vector& i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, - const std::vector& expand_accum_reverse, - std::vector* out_accum_reverse); +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse); /* * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], @@ -165,7 +165,7 @@ Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector& in, const std::vector& expand, std::vector* out); +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h index 4a64ab472c1..43286317c57 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h @@ -32,9 +32,9 @@ using Shapes = std::vector; class TensorInfo { public: - TensorInfo(const TensorLayout& tensor_layout, Shape shape, Shape slice_shape) + 
TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} - explicit TensorInfo(const TensorLayout& tensor_layout) : tensor_layout_(tensor_layout) { + explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { shape_ = tensor_layout.tensor_shape().array(); slice_shape_ = tensor_layout.slice_shape().array(); } @@ -44,7 +44,7 @@ class TensorInfo { TensorLayout tensor_layout() const { return tensor_layout_; } Shape slice_shape() const { return slice_shape_; } Shape shape() const { return shape_; } - void set_reduce_dim(const std::vector& dim) { reduce_dim_ = dim; } + void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } std::vector reduce_dim() const { return reduce_dim_; } private: diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc index 5fbd04431cb..f3498065f29 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc @@ -45,8 +45,8 @@ std::string TensorLayout::OriginToString() const { return buffer.str(); } -Status TensorLayout::Init(const Arrangement& device_arrangement, const Map& tensor_map, - const Arrangement& tensor_shape) { +Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, + const Arrangement &tensor_shape) { device_arrangement_origin_ = device_arrangement; tensor_map_origin_ = tensor_map; tensor_shape_origin_ = tensor_shape; @@ -64,8 +64,8 @@ Status TensorLayout::Init(const Arrangement& device_arrangement, const Map& tens } } -Status TensorLayout::InitFromVector(const std::vector& device_arrangement, - const std::vector& tensor_map, const std::vector& tensor_shape) { +Status TensorLayout::InitFromVector(const std::vector &device_arrangement, + const std::vector &tensor_map, const std::vector &tensor_shape) { if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { return FAILED; } @@ -124,7 +124,7 @@ void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { if (idx != -1) { tensor_map_shape[static_cast(idx)] = -1; } - for (auto& value : tensor_map_shape) { + for (auto &value : tensor_map_shape) { if (value >= dev_num_left - 1 - static_cast(i)) { value--; } @@ -153,7 +153,7 @@ int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { return device_arrangement_.GetDimByIdx(static_cast(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); } -std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement& expanded_shape) const { +std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { std::shared_ptr expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); if (expanded_arrangement_ptr == nullptr) { return nullptr; @@ -174,7 +174,7 @@ std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement& * => * out_device_arrangement = [8, 2, 2] */ -std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement& tensor_shape) const { +std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { std::shared_ptr> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); if (expand_list_ptr == nullptr) { return nullptr; @@ -204,7 +204,7 @@ std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(con * out_tensor_map = [1, -1, 0, -1], */ std::shared_ptr 
TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement& expanded_shape) const { + const Arrangement &expanded_shape) const { std::shared_ptr, Arrangement>> expand_list_pair_ptr = tensor_shape_.GetExpandShapeListPair(expanded_shape); if (expand_list_pair_ptr == nullptr) { @@ -259,7 +259,7 @@ std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDevice * out_tensor_map = [0, 2, 1], * out_tensor_shape = [512, 4, 256] */ -std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement& expanded_arrangement) const { +std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { std::shared_ptr, Arrangement>> expand_list_pair_ptr = device_arrangement_.GetExpandShapeListPair(expanded_arrangement); if (expand_list_pair_ptr == nullptr) { @@ -287,7 +287,7 @@ std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrang return std::make_shared(tensor_layout_new); } -bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement& expand_shape) const { +bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { std::vector in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { @@ -296,7 +296,7 @@ bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement& expand_shape) con return (in_expand_shape_shape == tensor_shape_.array()); } -std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement& expand_shape) const { +std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { std::vector in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { @@ -345,7 +345,7 @@ Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { return Status::SUCCESS; } -bool TensorLayout::operator==(const TensorLayout& t1) const { +bool TensorLayout::operator==(const TensorLayout &t1) const { return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); } diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h index e6ddc2a708e..f51ed4e3e0a 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h @@ -37,9 +37,9 @@ class TensorLayout { std::string ToString() const; std::string StandardToString() const; std::string OriginToString() const; - Status Init(const Arrangement& device_arrangement, const Map& tensor_map, const Arrangement& tensor_shape); - Status InitFromVector(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape); + Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); + Status InitFromVector(const std::vector &device_arrangement, const std::vector &tensor_map, + const std::vector &tensor_shape); Arrangement device_arrangement() const { return device_arrangement_; } @@ -49,25 +49,25 @@ class TensorLayout { Map origin_tensor_map() const { return tensor_map_origin_; } - std::shared_ptr ExpandTensorShape(const Arrangement& expanded_shape) const; + std::shared_ptr ExpandTensorShape(const Arrangement &expanded_shape) const; - std::shared_ptr ExpandDeviceArrangement(const Arrangement& expanded_arrangement) const; + std::shared_ptr ExpandDeviceArrangement(const Arrangement 
&expanded_arrangement) const; - bool IsSameTensorShape(const TensorLayout& tensor_layout) const { + bool IsSameTensorShape(const TensorLayout &tensor_layout) const { return (tensor_shape_ == tensor_layout.tensor_shape()); } - bool IsSameDeviceArrangement(const TensorLayout& tensor_layout) const { + bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { return (device_arrangement_ == tensor_layout.device_arrangement()); } - bool IsSameTensorMap(const TensorLayout& tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } + bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } - bool operator==(const TensorLayout& t1) const; + bool operator==(const TensorLayout &t1) const; - bool TensorShapeCanBeExpanded(const Arrangement& expanded_shape) const; + bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; - std::shared_ptr ComputeExpandedTensorShape(const Arrangement& expand_shape) const; + std::shared_ptr ComputeExpandedTensorShape(const Arrangement &expand_shape) const; Arrangement slice_shape() const; @@ -77,8 +77,8 @@ class TensorLayout { private: std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement& expanded_shape) const; - std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement& tensor_shape) const; + const Arrangement &expanded_shape) const; + std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; bool IsValidTensorLayout() const; void RemoveElementEqualToOneInDeviceArrangement(); int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc index 460cd9d1bd5..7824c21f3d5 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace parallel { -Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list) { +Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) { from_origin_ = from; to_origin_ = to; if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) { @@ -87,9 +87,9 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL std::make_pair(operator_vector, output_info_vector)); } -Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, - OperatorVector* const operator_vector, - OutPutInfoVector* const output_info_vector) { +Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, + OutPutInfoVector *const output_info_vector) { MS_EXCEPTION_IF_NULL(operator_vector); MS_EXCEPTION_IF_NULL(output_info_vector); ConstructOperator constructor; @@ -144,7 +144,7 @@ Status TensorRedistribution::ComputeCost() { return Status::FAILED; } // Compute redistribution communication cost and computation cost - for (auto& op_cost : operator_list_) { + for (auto &op_cost : operator_list_) { OperatorR op = op_cost.first; Shape slice_shape = op_cost.second; double prod = diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h index 71d4a02701f..e7800909c5f 
100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h @@ -46,7 +46,7 @@ class TensorRedistribution { memory_cost_(0.0), construct_op_flag_(construct_op_flag), keep_reshape_(keep_reshape) {} - Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); + Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list); ~TensorRedistribution() = default; RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); OperatorList operator_list() const { return operator_list_; } @@ -59,8 +59,8 @@ class TensorRedistribution { double memory_cost() const { return memory_cost_; } private: - Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, - OperatorVector* const operator_vector, OutPutInfoVector* const output_info_vector); + Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); TensorLayout from_origin_; TensorLayout to_origin_; diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 3e0f8804e71..e8723e66a4b 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -41,8 +41,8 @@ using CompileGraphs = compile::CompileGraphs; using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec, bool clear) { +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear) { MS_LOG(DEBUG) << "AbstractAnalyze start"; auto engine = res->engine(); MS_EXCEPTION_IF_NULL(engine); @@ -50,9 +50,9 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraph auto manager = res->manager(); MS_EXCEPTION_IF_NULL(manager); engine->Clear(); - for (auto& node : manager->all_nodes()) { + for (auto &node : manager->all_nodes()) { MS_EXCEPTION_IF_NULL(node); - const AbstractBasePtr& prev_inferred = node->abstract(); + const AbstractBasePtr &prev_inferred = node->abstract(); // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. 
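The shape_util declarations earlier in this diff document a family of accumulate-product transforms with worked examples (for instance shape = [2, 8, 32] maps to shape_accum = [2, 16, 512], and the reverse variant to [512, 256, 32]); UnifyShape and ExpandShape then operate on these accumulated forms, as their bodies show. As a rough illustration of the forward and inverse transforms only, here is a minimal standalone sketch; the function and type names are placeholders, and this is not the MindSpore implementation, which returns Status and validates its inputs.

// Minimal sketch of the accumulate-product transforms documented in
// shape_util.h above; simplified stand-ins, not the MindSpore functions.
#include <cassert>
#include <cstdint>
#include <vector>

// [2, 8, 32] -> [2, 2*8, 2*8*32] = [2, 16, 512]
std::vector<int64_t> ShapeToAccumProduct(const std::vector<int64_t> &shape) {
  std::vector<int64_t> accum;
  int64_t prod = 1;
  for (int64_t dim : shape) {
    prod *= dim;
    accum.push_back(prod);
  }
  return accum;
}

// Inverse transform: [2, 16, 512] -> [2, 8, 32]
std::vector<int64_t> AccumProductToShape(const std::vector<int64_t> &accum) {
  std::vector<int64_t> shape;
  int64_t prev = 1;
  for (int64_t value : accum) {
    shape.push_back(value / prev);
    prev = value;
  }
  return shape;
}

int main() {
  std::vector<int64_t> shape = {2, 8, 32};
  std::vector<int64_t> accum = ShapeToAccumProduct(shape);
  assert((accum == std::vector<int64_t>{2, 16, 512}));
  assert(AccumProductToShape(accum) == shape);
  return 0;
}

In the diff itself, UnifyShape converts both input shapes to this accumulated form, merges the two product lists, and converts the union back to a shape, which is how the [2, 16] example yields [2, 4, 4].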
if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { node->set_abstract(nullptr); @@ -65,8 +65,8 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraph return ret; } -FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AnalysisContextPtr& context) { +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context) { MS_LOG(DEBUG) << "ProgramSpecialize start"; abstract::ProgramSpecializer spc(res->engine()); FuncGraphPtr result = spc.Run(func_graph, context); @@ -77,8 +77,8 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_ return result; } -FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec) { +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec) { MS_LOG(DEBUG) << "Renormalize start"; #ifdef ENABLE_PROFILE double t1 = GetTime(); @@ -98,7 +98,7 @@ FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, return ret; } -bool ParseAction(const ResourcePtr& res) { +bool ParseAction(const ResourcePtr &res) { if (!res->input()) { MS_LOG(EXCEPTION) << "Parse error"; } @@ -129,11 +129,11 @@ bool ParseAction(const ResourcePtr& res) { // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}-> // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx} // all obj_map's graph shared base_graph -bool CombineLikeGraphs(const ResourcePtr&) { - auto& obj_map = parse::data_converter::GetObjGraphs(); +bool CombineLikeGraphs(const ResourcePtr &) { + auto &obj_map = parse::data_converter::GetObjGraphs(); for (auto it : obj_map) { - auto& graphs = it.second; + auto &graphs = it.second; MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size(); auto fg = graphs[0]; FuncGraphPtrList func_graphs = {fg}; @@ -147,7 +147,7 @@ bool CombineLikeGraphs(const ResourcePtr&) { continue; } auto mng = Manage(base_graph, false); - for (auto& fv : fg->paramter_obj_nodes()) { + for (auto &fv : fg->paramter_obj_nodes()) { TraceManager::DebugTrace(std::make_shared(fv->debug_info())); auto param = base_graph->add_parameter(); TraceManager::EndTrace(); @@ -156,11 +156,11 @@ bool CombineLikeGraphs(const ResourcePtr&) { } MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); - for (auto& g : graphs) { + for (auto &g : graphs) { auto fvs = g->paramter_obj_nodes(); std::vector new_node_inputs; new_node_inputs.push_back(NewValueNode(base_graph)); - for (auto& p : g->parameters()) { + for (auto &p : g->parameters()) { AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p); new_node_inputs.push_back(para_after_cast); } @@ -174,7 +174,7 @@ bool CombineLikeGraphs(const ResourcePtr&) { return true; } -bool SymbolResolveAction(const ResourcePtr& res) { +bool SymbolResolveAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; } @@ -195,7 +195,7 @@ bool SymbolResolveAction(const ResourcePtr& res) { return succ; } -bool InferenceOptPrepareAction(const ResourcePtr& res) { +bool InferenceOptPrepareAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; } @@ -205,7 +205,7 @@ 
bool InferenceOptPrepareAction(const ResourcePtr& res) { return InferenceOptPreparePass(res); } -bool AbstractSpecializeAction(const ResourcePtr& res) { +bool AbstractSpecializeAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "AbstractSpecialize error"; } @@ -215,7 +215,7 @@ bool AbstractSpecializeAction(const ResourcePtr& res) { // suppose that there is not KeywordArgument for the top graph // get the hyper parameter - for (const auto& param : func_graph->parameters()) { + for (const auto ¶m : func_graph->parameters()) { auto param_node = std::static_pointer_cast(param); if (param_node->has_default()) { AbstractBasePtr ptr = @@ -236,8 +236,8 @@ bool AbstractSpecializeAction(const ResourcePtr& res) { return true; } -bool OptimizeAction(const ResourcePtr& res, const std::vector& passes) { - for (auto& pass : passes) { +bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { + for (auto &pass : passes) { WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() { MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; auto result = pass.second(res); @@ -251,11 +251,11 @@ bool OptimizeAction(const ResourcePtr& res, const std::vector& passes) return true; } -bool GeOptimizeAction(const ResourcePtr& res) { return OptimizeAction(res, kGePasses); } +bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } -bool VmOptimizeAction(const ResourcePtr& res) { return OptimizeAction(res, kVmPasses); } +bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } -bool TaskEmitAction(const ResourcePtr& res) { +bool TaskEmitAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "TaskEmit args error"; } @@ -271,7 +271,7 @@ bool TaskEmitAction(const ResourcePtr& res) { return true; } -bool ExecuteAction(const ResourcePtr& res) { +bool ExecuteAction(const ResourcePtr &res) { if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } @@ -291,11 +291,11 @@ bool ExecuteAction(const ResourcePtr& res) { // that will result in a syncronization error due to different executing order. // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, // the final solution will be proposed later as a parallel feature. 
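OptimizeAction above drives a vector of named passes over the compilation resource, logging each pass and treating a false return as a hard failure; GeOptimizeAction and VmOptimizeAction are just that driver applied to kGePasses and kVmPasses. Below is a minimal sketch of the same driver pattern, with Resource and the pass list reduced to placeholders; the real code additionally wraps each pass in MsProfile steps and raises through MS_LOG(EXCEPTION) on failure.

// Minimal sketch of the "run a named list of passes, fail fast" pattern used
// by OptimizeAction above; Resource and the pass list are stand-ins, not the
// actual pipeline types.
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Resource {
  // In the real pipeline this would own the func graph, manager, results, ...
};

using PassItem = std::pair<std::string, std::function<bool(Resource *)>>;

bool RunPasses(Resource *res, const std::vector<PassItem> &passes) {
  for (const auto &pass : passes) {
    std::cout << "Pass " << pass.first << " start ..." << std::endl;
    if (!pass.second(res)) {
      std::cout << "Pass " << pass.first << " failed." << std::endl;
      return false;  // fail fast, as OptimizeAction does on a failed pass
    }
    std::cout << "Pass " << pass.first << " end." << std::endl;
  }
  return true;
}

int main() {
  Resource res;
  std::vector<PassItem> passes = {
      {"simplify", [](Resource *) { return true; }},
      {"validate", [](Resource *) { return true; }},
  };
  return RunPasses(&res, passes) ? 0 : 1;
}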
-bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& res) { - auto& node_users = res->manager()->node_users(); - auto& users = node_users[value_node]; +bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { + auto &node_users = res->manager()->node_users(); + auto &users = node_users[value_node]; auto used_by_keep_value_prim = - std::any_of(users.begin(), users.end(), [](const std::pair& user) -> bool { + std::any_of(users.begin(), users.end(), [](const std::pair &user) -> bool { MS_EXCEPTION_IF_NULL(user.first); auto cnode = user.first->cast(); if (cnode == nullptr) { @@ -312,7 +312,7 @@ bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& r return used_by_keep_value_prim; } -bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { +bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "Remove value node duplications error."; } @@ -322,7 +322,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { auto value_nodes = manager->valuenodes()[func_graph]; HashCache hash_cache; HashValue hashes; - for (const auto& value_pair : value_nodes) { + for (const auto &value_pair : value_nodes) { if (KeepValueNodeDuplication(value_pair.first, res)) { continue; } @@ -331,7 +331,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { return true; } -bool ValidateAction(const ResourcePtr& res) { return ValidatePass(res); } +bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } static std::vector CommonPipeline() { std::vector actions; diff --git a/mindspore/ccsrc/pipeline/action.h b/mindspore/ccsrc/pipeline/action.h index 159e494a964..8a651c00383 100644 --- a/mindspore/ccsrc/pipeline/action.h +++ b/mindspore/ccsrc/pipeline/action.h @@ -30,22 +30,22 @@ extern const char kMsConvert[]; namespace pipeline { using ActionItem = std::pair>; -bool ParseAction(const ResourcePtr& res); -bool SymbolResolveAction(const ResourcePtr& res); -bool AbstractSpecializeAction(const ResourcePtr& res); -bool GeOptimizeAction(const ResourcePtr& res); -bool VmOptimizeAction(const ResourcePtr& res); -bool TaskEmitAction(const ResourcePtr& res); -bool ExecuteAction(const ResourcePtr& res); +bool ParseAction(const ResourcePtr &res); +bool SymbolResolveAction(const ResourcePtr &res); +bool AbstractSpecializeAction(const ResourcePtr &res); +bool GeOptimizeAction(const ResourcePtr &res); +bool VmOptimizeAction(const ResourcePtr &res); +bool TaskEmitAction(const ResourcePtr &res); +bool ExecuteAction(const ResourcePtr &res); std::vector GePipeline(); std::vector VmPipeline(); -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec, bool clear = false); -FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AnalysisContextPtr& context); -FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec); +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear = false); +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context); +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList 
&args_spec); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/base.h b/mindspore/ccsrc/pipeline/base.h index 30524e84f68..8ca153f45b4 100644 --- a/mindspore/ccsrc/pipeline/base.h +++ b/mindspore/ccsrc/pipeline/base.h @@ -37,7 +37,7 @@ struct ExecutorInfo { using ExecutorInfoPtr = std::shared_ptr; -inline std::string GetPhasePrefix(const std::string& phase) { +inline std::string GetPhasePrefix(const std::string &phase) { auto pos = phase.find('.'); if (pos == std::string::npos) { MS_LOG(EXCEPTION) << "Phase has no . for prefix" << phase; @@ -45,7 +45,7 @@ inline std::string GetPhasePrefix(const std::string& phase) { return phase.substr(0, pos); } -inline std::string GetFilePathName(const std::string& file_name) { +inline std::string GetFilePathName(const std::string &file_name) { std::ostringstream oss; auto ms_context = MsContext::GetInstance(); if (ms_context == nullptr) { diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index b709199c870..86e6d436b7c 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -53,10 +53,10 @@ PYBIND11_MODULE(_c_expression, m) { (void)py::class_>(*m, "MetaFuncGraph_") .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) - .def(py::init()); + .def(py::init()); auto fns = mindspore::PybindDefineRegister::AllFuncs(); - for (auto& item : fns) { + for (auto &item : fns) { item.second(&m); } @@ -288,7 +288,7 @@ PYBIND11_MODULE(_c_expression, m) { }}); (void)py::class_>(m, "EventWriter_") - .def(py::init()) + .def(py::init()) .def("GetFileName", &EventWriter::GetFileName, "Get the file name.") .def("Open", &EventWriter::Open, "Open the write file.") .def("Write", &EventWriter::Write, "Write the serialize event.") diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc index d25a202afc4..861fc0eda88 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc @@ -38,7 +38,7 @@ using Tensor = mindspore::tensor::Tensor; using TensorPtr = mindspore::tensor::TensorPtr; namespace { -bool ConvertTuple(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python tuple"; py::tuple tuple = obj.cast(); std::vector value_list; @@ -55,7 +55,7 @@ bool ConvertTuple(const py::object& obj, ValuePtr* const data, bool use_signatur return true; } -bool ConvertList(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python list"; py::list list = obj.cast(); @@ -72,7 +72,7 @@ bool ConvertList(const py::object& obj, ValuePtr* const data, bool use_signature return true; } -bool ConvertCellList(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting cell list"; py::sequence list = obj; std::vector value_list; @@ -88,7 +88,7 @@ bool ConvertCellList(const py::object& obj, ValuePtr* const data, bool use_signa return true; } -bool ConvertDict(const py::object& obj, ValuePtr* data, bool use_signature) { +bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { MS_LOG(DEBUG) << "Converting python dict"; py::dict dict_values = obj.cast(); @@ -109,14 
+109,14 @@ bool ConvertDict(const py::object& obj, ValuePtr* data, bool use_signature) { return true; } -void ConvertNameSpace(const py::object& obj, ValuePtr* const data) { +void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting python module"; py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); } -void ConvertDataClass(py::object obj, ValuePtr* const data) { +void ConvertDataClass(py::object obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting dataclass"; // Maybe the obj is dataclass define auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); @@ -124,7 +124,7 @@ void ConvertDataClass(py::object obj, ValuePtr* const data) { *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); } -bool ConvertPrimitive(py::object obj, ValuePtr* const data, bool use_signature = false) { +bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { MS_LOG(DEBUG) << "Converting primitive object"; // need check the primitive is class type or instance @@ -155,7 +155,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr* const data, bool use_signature = return true; } -bool ConvertMetaFuncGraph(const py::object& obj, ValuePtr* const data, bool use_signature = false) { +bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; auto meta = obj.cast(); if (meta == nullptr) { @@ -170,7 +170,7 @@ bool ConvertMetaFuncGraph(const py::object& obj, ValuePtr* const data, bool use_ return true; } -bool ConvertDataType(const py::object& obj, ValuePtr* const data) { +bool ConvertDataType(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting type object"; auto typeptr = obj.cast(); if (typeptr == nullptr) { @@ -181,7 +181,7 @@ bool ConvertDataType(const py::object& obj, ValuePtr* const data) { return true; } -bool ConvertTensor(const py::object& obj, ValuePtr* const data) { +bool ConvertTensor(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting tensor object"; auto m_tensor = obj.cast(); @@ -193,7 +193,7 @@ bool ConvertTensor(const py::object& obj, ValuePtr* const data) { return true; } -bool ConvertOtherObj(py::object obj, ValuePtr* const data) { +bool ConvertOtherObj(py::object obj, ValuePtr *const data) { auto obj_type = data_converter::GetObjType(obj); MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { @@ -244,7 +244,7 @@ bool ConvertOtherObj(py::object obj, ValuePtr* const data) { } } // namespace -bool ConvertData(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { // check parameter valid if (data == nullptr) { MS_LOG(ERROR) << "Data is null pointer"; @@ -295,7 +295,7 @@ bool ConvertData(const py::object& obj, ValuePtr* const data, bool use_signature } // convert data to graph -FuncGraphPtr ConvertToFuncGraph(const py::object& obj, const std::string& python_mod_get_parse_method) { +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { std::vector results = data_converter::GetObjKey(obj); 
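The data_converter routines above share one contract: each takes a py::object plus an out-pointer to the converted value, returns bool, and ConvertData dispatches among them by the runtime Python type. Here is a toy sketch of that dispatch contract, with a std::variant standing in for py::object so it compiles without pybind11; every type and converter name below is illustrative only.

// Minimal sketch of the "convert or report failure through an out-pointer"
// dispatch used by ConvertData above; the input and Value types here are
// illustrative stand-ins, not the pybind11 or MindSpore types.
#include <cstdint>
#include <memory>
#include <string>
#include <variant>

struct Value {
  std::string repr;
};
using ValuePtr = std::shared_ptr<Value>;

// A toy stand-in for py::object: either an integer or a string.
using PyLikeObject = std::variant<int64_t, std::string>;

bool ConvertInt(int64_t v, ValuePtr *data) {
  *data = std::make_shared<Value>(Value{std::to_string(v)});
  return true;
}

bool ConvertStr(const std::string &s, ValuePtr *data) {
  *data = std::make_shared<Value>(Value{s});
  return true;
}

// Dispatcher: inspect the runtime type, call the matching converter,
// and report failure instead of throwing when nothing matches.
bool ConvertData(const PyLikeObject &obj, ValuePtr *data) {
  if (data == nullptr) {
    return false;
  }
  if (std::holds_alternative<int64_t>(obj)) {
    return ConvertInt(std::get<int64_t>(obj), data);
  }
  if (std::holds_alternative<std::string>(obj)) {
    return ConvertStr(std::get<std::string>(obj), data);
  }
  return false;
}

int main() {
  ValuePtr value;
  return ConvertData(PyLikeObject{int64_t{42}}, &value) ? 0 : 1;
}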
std::string obj_id = results[0] + python_mod_get_parse_method; std::string obj_key = results[1]; @@ -331,25 +331,25 @@ static std::unordered_map object_map_ = std::unordered_map> object_graphs_map_ = std::unordered_map>(); -void SetObjGraphValue(const std::string& obj_key, const FuncGraphPtr& data) { +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { object_graphs_map_[obj_key].push_back(data); MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); } -const std::unordered_map>& GetObjGraphs() { +const std::unordered_map> &GetObjGraphs() { MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); return object_graphs_map_; } -void CacheObjectValue(const std::string& obj_key, const Any& data) { object_map_[obj_key] = data; } -bool GetObjectValue(const std::string& obj_key, Any* const data) { +void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } +bool GetObjectValue(const std::string &obj_key, Any *const data) { if (object_map_.count(obj_key)) { *data = object_map_[obj_key]; return true; } return false; } -std::vector GetObjKey(const py::object& obj) { +std::vector GetObjKey(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); if (obj_tuple.size() != 2) { @@ -359,7 +359,7 @@ std::vector GetObjKey(const py::object& obj) { } // get obj detail type -ResolveTypeDef GetObjType(const py::object& obj) { +ResolveTypeDef GetObjType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto obj_type = ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast()); @@ -367,7 +367,7 @@ ResolveTypeDef GetObjType(const py::object& obj) { } // get class instance detail type -ClassInstanceTypeDef GetClassInstanceType(const py::object& obj) { +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto class_type = ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast()); @@ -375,14 +375,14 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object& obj) { } // check the object is Cell Instance -bool IsCellInstance(const py::object& obj) { +bool IsCellInstance(const py::object &obj) { auto class_type = GetClassInstanceType(obj); bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); return isCell; } // create the python class instance -py::object CreatePythonObject(const py::object& type, const py::tuple& params) { +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object obj; if (params.size() == 0) { @@ -395,7 +395,7 @@ py::object CreatePythonObject(const py::object& type, const py::tuple& params) { // Generate an appropriate name and set to graph debuginfo // character <> can not used in the dot file, so change to another symbol -void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name) { +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph->debug_info()); // set detail name info of function @@ -412,7 +412,7 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name) func_graph->debug_info()->set_full_name(oss.str()); } -ValuePtr 
PyDataToValue(const py::object& obj) { +ValuePtr PyDataToValue(const py::object &obj) { py::object to_convert = obj; if (py::hasattr(obj, "__parameter__")) { to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); @@ -431,7 +431,7 @@ void ClearObjectCache() { static std::unordered_map g_dataClassToClass = {}; // parse dataclass to mindspore Class type -ClassPtr ParseDataClass(const py::object& cls_obj) { +ClassPtr ParseDataClass(const py::object &cls_obj) { std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); std::string cls = cls_module + "." + cls_name; @@ -443,7 +443,7 @@ ClassPtr ParseDataClass(const py::object& cls_obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); ClassAttrVector attributes; py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); - for (auto& item : names) { + for (auto &item : names) { TypePtr type_value = item.second.cast(); MS_EXCEPTION_IF_NULL(type_value); MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; @@ -452,7 +452,7 @@ ClassPtr ParseDataClass(const py::object& cls_obj) { std::unordered_map methods_map; py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); - for (auto& item : methods) { + for (auto &item : methods) { std::string fun_name = item.first.cast(); py::object obj = py::cast(item.second); std::shared_ptr method_obj = std::make_shared(obj, fun_name); diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.h b/mindspore/ccsrc/pipeline/parse/data_converter.h index 658360bcee5..a8918fa60c1 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/parse/data_converter.h @@ -32,25 +32,25 @@ namespace mindspore { namespace parse { // data convert for parse namespace data_converter { -void CacheObjectValue(const std::string& obj_key, const Any& data); -bool GetObjectValue(const std::string& obj_key, Any* const data); +void CacheObjectValue(const std::string &obj_key, const Any &data); +bool GetObjectValue(const std::string &obj_key, Any *const data); -void SetObjGraphValue(const std::string& obj_key, const FuncGraphPtr& data); +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); -const std::unordered_map>& GetObjGraphs(); +const std::unordered_map> &GetObjGraphs(); -std::vector GetObjKey(const py::object& obj); -ResolveTypeDef GetObjType(const py::object& obj); -ClassInstanceTypeDef GetClassInstanceType(const py::object& obj); +std::vector GetObjKey(const py::object &obj); +ResolveTypeDef GetObjType(const py::object &obj); +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); -bool IsCellInstance(const py::object& obj); -py::object CreatePythonObject(const py::object& type, const py::tuple& params); -void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name); -ValuePtr PyDataToValue(const py::object& obj); +bool IsCellInstance(const py::object &obj); +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); +ValuePtr PyDataToValue(const py::object &obj); void ClearObjectCache(); } // namespace data_converter -ClassPtr ParseDataClass(const py::object& cls_obj); +ClassPtr ParseDataClass(const py::object &cls_obj); void CleanDataClassToClassMap(); diff --git 
a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc index 423e76c1d87..156f727b9e4 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/parse/function_block.cc @@ -28,21 +28,21 @@ namespace mindspore { namespace parse { -FunctionBlock::FunctionBlock(const Parser& parser) : parser_(parser) { +FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { func_graph_ = std::make_shared(); matured_ = false; } -void FunctionBlock::AddPrevBlock(const FunctionBlockPtr& block) { prev_blocks_.push_back(block.get()); } +void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } // write variable records the variable name to corresponding node -void FunctionBlock::WriteVariable(const std::string& var_name, const AnfNodePtr& node) { +void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { MS_LOG(DEBUG) << "" << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); vars_[var_name] = node; } // read variable from predecessors -AnfNodePtr FunctionBlock::ReadVariable(const std::string& var) { +AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { // get var node if it is found if (vars_.count(var)) { AnfNodePtr node = vars_[var]; @@ -82,7 +82,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string& var) { } // Resolve Ast operator node -AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object& op) { +AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) { auto ast = parser_.ast(); MS_EXCEPTION_IF_NULL(ast); TraceGuard trace_guard(parser_.GetLocation(op)); @@ -105,7 +105,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { } // Make a resolve node for symbol string -AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string& value) { +AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { if (value.compare(0, strlen("self."), "self.") == 0) { auto start = value.find_first_of('.') + 1; if (start >= value.size()) { @@ -122,14 +122,14 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string& value) { return MakeResolve(name_space, symbol); } -AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string& value) { +AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) { py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value); NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]); SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); return MakeResolve(name_space, symbol); } -AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr& name_space, const SymbolPtr& resolve_symbol) { +AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) { MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , " << ((std::string)resolve_symbol->symbol()); ValueNodePtr module_node = NewValueNode(name_space); @@ -139,10 +139,10 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr& name_space, const Symb } // add input for the block's phi parameter -void FunctionBlock::SetPhiArgument(const ParameterPtr& phi) { +void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; - for (auto& 
pred : prev_blocks_) { + for (auto &pred : prev_blocks_) { MS_EXCEPTION_IF_NULL(pred); MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); AnfNodePtr arg_node = pred->ReadVariable(var); @@ -161,9 +161,9 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr& phi) { } } -AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string& var, const ParameterPtr& phi) { +AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { AnfNodePtr arg_node = nullptr; - for (auto& prev : prev_blocks_) { + for (auto &prev : prev_blocks_) { MS_EXCEPTION_IF_NULL(prev); AnfNodePtr temp_node = prev->ReadVariable(var); MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var @@ -204,7 +204,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string& var, const Parame // 2. it's costly to iterate the graph to replace the phi for each phi. // Args : // phi : This parameter node is functioning as a phi node. -void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { +void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { MS_EXCEPTION_IF_NULL(phi); std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); @@ -221,15 +221,15 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { removable_phis_[phi] = arg_node; // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. - for (auto& prev : prev_blocks_) { + for (auto &prev : prev_blocks_) { MS_EXCEPTION_IF_NULL(prev); if (!prev->matured_) { continue; } - for (auto& phi_iter : prev->removable_phis_) { + for (auto &phi_iter : prev->removable_phis_) { MS_EXCEPTION_IF_NULL(phi_iter.second); if (phi_iter.second->isa()) { - const auto& param = phi_iter.second->cast(); + const auto &param = phi_iter.second->cast(); if (param == phi) { MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); @@ -243,8 +243,8 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { // A block should be marked matured if its predecessor blocks have been processed void FunctionBlock::Mature() { - const auto& graphParamVec = func_graph_->parameters(); - for (auto& paramItr : graphParamVec) { + const auto &graphParamVec = func_graph_->parameters(); + for (auto &paramItr : graphParamVec) { MS_EXCEPTION_IF_NULL(paramItr); ParameterPtr param = paramItr->cast(); if (phi_nodes_.find(param) != phi_nodes_.cend()) { @@ -255,7 +255,7 @@ void FunctionBlock::Mature() { } // Force the conditIon node to bool using bool operation -CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr& cond) { +CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { TraceManager::DebugTrace(std::make_shared(cond->debug_info())); CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); TraceManager::EndTrace(); @@ -263,7 +263,7 @@ CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr& cond) { } // Perform a jump from this block to target block -void FunctionBlock::Jump(const FunctionBlockPtr& target_block, AnfNodePtr node) { +void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) { 
if (func_graph()->get_return() != nullptr) { MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); @@ -283,8 +283,8 @@ void FunctionBlock::Jump(const FunctionBlockPtr& target_block, AnfNodePtr node) // Perform a conditional jump using switch operation. // The first CNode select graph with condition, and than execute this graph -void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& true_block, - const FunctionBlockPtr& false_block) { +void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block) { if (func_graph()->get_return() != nullptr) { MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); @@ -297,15 +297,15 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& InsertDependItemsBeforeReturn(); } -void FunctionBlock::SetStateAssgin(const AnfNodePtr& target, const std::string& readid) { +void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { state_assign_[target] = readid; } -void FunctionBlock::AddAutoDepend(const AnfNodePtr& target) { auto_depends_.push_back(target); } +void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } void FunctionBlock::InsertDependItemsBeforeReturn() { if (!prev_blocks_.empty()) { - for (auto& prev_block : prev_blocks_) { + for (auto &prev_block : prev_blocks_) { MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); } } @@ -324,14 +324,14 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { AnfNodePtr state = nullptr; std::vector vec_states; vec_states.emplace_back(make_tuple_op); - for (auto& item : state_assign_) { + for (auto &item : state_assign_) { auto source = ReadVariable(item.second); auto refkey = func_graph()->NewCNode({get_refkey_op, item.first}); auto assign = func_graph()->NewCNode({assign_op, refkey, source}); MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; vec_states.emplace_back(assign); } - for (auto& item : auto_depends_) { + for (auto &item : auto_depends_) { MS_LOG(DEBUG) << "auto_depends " << item->ToString(); vec_states.emplace_back(item); } diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h index 0be6e472f84..e7842903ee2 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/parse/function_block.h @@ -43,47 +43,47 @@ using FunctionBlockPtr = std::shared_ptr; // the original source code. 
class FunctionBlock : public std::enable_shared_from_this { public: - explicit FunctionBlock(const Parser& parser); + explicit FunctionBlock(const Parser &parser); virtual ~FunctionBlock() {} FuncGraphPtr func_graph() { return func_graph_; } - void WriteVariable(const std::string& var_name, const AnfNodePtr& node); - AnfNodePtr ReadVariable(const std::string& var_name); - void AddPrevBlock(const FunctionBlockPtr& block); - void SetPhiArgument(const ParameterPtr& phi); - void CollectRemovablePhi(const ParameterPtr& phi); + void WriteVariable(const std::string &var_name, const AnfNodePtr &node); + AnfNodePtr ReadVariable(const std::string &var_name); + void AddPrevBlock(const FunctionBlockPtr &block); + void SetPhiArgument(const ParameterPtr &phi); + void CollectRemovablePhi(const ParameterPtr &phi); // A block is matured if all its predecessors is generated void Mature(); - CNodePtr ForceToBoolNode(const AnfNodePtr& cond); - void Jump(const FunctionBlockPtr& block, AnfNodePtr node); - AnfNodePtr SearchReplaceNode(const std::string& var, const ParameterPtr& phi); - void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& trueBlock, const FunctionBlockPtr& falseBlock); + CNodePtr ForceToBoolNode(const AnfNodePtr &cond); + void Jump(const FunctionBlockPtr &block, AnfNodePtr node); + AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); + void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock); // record the assign statement of self.xx weight parameter ,which will use state_setitem op - void SetStateAssgin(const AnfNodePtr& target, const std::string& readid); - void AddAutoDepend(const AnfNodePtr& target); + void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); + void AddAutoDepend(const AnfNodePtr &target); void InsertDependItemsBeforeReturn(); - void AddGlobalVar(const std::string& var_name) { (void)global_vars_.insert(var_name); } - bool IsGlobalVar(const std::string& var_name) { return global_vars_.find(var_name) != global_vars_.end(); } - AnfNodePtr MakeResolveAstOp(const py::object& op); + void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } + bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } + AnfNodePtr MakeResolveAstOp(const py::object &op); AnfNodePtr MakeResolveClassMember(std::string attr); - AnfNodePtr MakeResolveSymbol(const std::string& value); - AnfNodePtr MakeResolveOperation(const std::string& value); - AnfNodePtr MakeResolve(const std::shared_ptr& name_space, const std::shared_ptr& resolve_symbol); - const std::unordered_map& removable_phis() const { return removable_phis_; } + AnfNodePtr MakeResolveSymbol(const std::string &value); + AnfNodePtr MakeResolveOperation(const std::string &value); + AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); + const std::unordered_map &removable_phis() const { return removable_phis_; } private: // block graph FuncGraphPtr func_graph_; // the block's parser - const Parser& parser_; + const Parser &parser_; // A block is matured if all its prev_blocks is processed bool matured_; // store the nest-level block // refer to comments in Parser::func_block_list_; - std::vector prev_blocks_; + std::vector prev_blocks_; // store args and variable's node std::map vars_; @@ -93,7 +93,7 @@ class FunctionBlock : public std::enable_shared_from_this { // jumps map the successor block and the function call 
that perform jump // refer to comments in Parser::func_block_list_ that how to break the cyclic reference - std::map jumps_; + std::map jumps_; // keeps all removable phis which will be removed in one pass. std::unordered_map removable_phis_; diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h index df2d1968a51..aad8be0d6e9 100644 --- a/mindspore/ccsrc/pipeline/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/parse/parse_base.h @@ -128,15 +128,15 @@ enum ClassInstanceTypeDef { }; // Convert python object to ValuePtr -bool ConvertData(const py::object& obj, ValuePtr* data, bool use_signature = false); +bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false); // Convert python obj to graph -FuncGraphPtr ConvertToFuncGraph(const py::object& obj, - const std::string& python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); // Parse the python object to graph -FuncGraphPtr ParsePythonCode(const py::object& obj, - const std::string& python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +FuncGraphPtr ParsePythonCode(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/parse/python_adapter.cc index e2c86164d4a..df2f7d0d45d 100644 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.cc +++ b/mindspore/ccsrc/pipeline/parse/python_adapter.cc @@ -32,7 +32,7 @@ void set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_i bool UseSignatureInResolve() { return use_signature_in_resolve_; } void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } bool IsPythonEnv() { return python_env_; } -void SetPythonPath(const std::string& path) { +void SetPythonPath(const std::string &path) { // load the python module path (void)python_adapter::set_python_scoped(); py::module sys = py::module::import("sys"); @@ -62,7 +62,7 @@ std::shared_ptr set_python_scoped() { } // return the module of python -py::module GetPyModule(const std::string& module) { +py::module GetPyModule(const std::string &module) { if (!module.empty()) { return py::module::import(module.c_str()); } else { @@ -71,7 +71,7 @@ py::module GetPyModule(const std::string& module) { } // Get the obj of attr -py::object GetPyObjAttr(const py::object& obj, const std::string& attr) { +py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { if (!attr.empty() && !py::isinstance(obj)) { if (py::hasattr(obj, attr.c_str())) { return obj.attr(attr.c_str()); @@ -81,7 +81,7 @@ py::object GetPyObjAttr(const py::object& obj, const std::string& attr) { return py::none(); } -py::object GetPyFn(const std::string& module, const std::string& name) { +py::object GetPyFn(const std::string &module, const std::string &name) { (void)python_adapter::set_python_scoped(); if (!module.empty() && !name.empty()) { py::module mod = py::module::import(module.c_str()); diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.h b/mindspore/ccsrc/pipeline/parse/python_adapter.h index 12cfc271868..98adcd4f731 100644 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.h +++ b/mindspore/ccsrc/pipeline/parse/python_adapter.h @@ -31,10 +31,10 @@ namespace mindspore { namespace parse { // A utility to call python interface namespace 
python_adapter { -py::module GetPyModule(const std::string& module); -py::object GetPyObjAttr(const py::object& obj, const std::string& attr); +py::module GetPyModule(const std::string &module); +py::object GetPyObjAttr(const py::object &obj, const std::string &attr); template -py::object CallPyObjMethod(const py::object& obj, const std::string& method, T... args) { +py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { if (!method.empty() && !py::isinstance(obj)) { return obj.attr(method.c_str())(args...); } @@ -43,7 +43,7 @@ py::object CallPyObjMethod(const py::object& obj, const std::string& method, T.. // call python function of module template -py::object CallPyModFn(const py::module& mod, const std::string& function, T... args) { +py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { if (!function.empty() && !py::isinstance(mod)) { return mod.attr(function.c_str())(args...); } @@ -57,12 +57,12 @@ bool UseSignatureInResolve(); std::shared_ptr set_python_scoped(); void ResetPythonScope(); bool IsPythonEnv(); -void SetPythonPath(const std::string& path); +void SetPythonPath(const std::string &path); void set_python_env_flag(bool python_env) noexcept; -py::object GetPyFn(const std::string& module, const std::string& name); +py::object GetPyFn(const std::string &module, const std::string &name); // Call the python function template -py::object CallPyFn(const std::string& module, const std::string& name, T... args) { +py::object CallPyFn(const std::string &module, const std::string &name, T... args) { (void)set_python_scoped(); if (!module.empty() && !name.empty()) { py::module mod = py::module::import(module.c_str()); diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc index f90fc5039c0..284512c9430 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/parse/resolve.cc @@ -71,7 +71,7 @@ bool SymbolResolver::Resolve() { namespace { // argument obj should be python Parameter object // it will be converted to Parameter node here -AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& obj) { +AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { MS_EXCEPTION_IF_NULL(func_graph); // parameter object should not be none @@ -128,7 +128,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& } } -bool ResolveObjectToNode(const FuncGraphPtr& func_graph, const py::object& obj, AnfNodePtr* const node) { +bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { AnfNodePtr output = nullptr; if (py::hasattr(obj, "__parameter__")) { auto param = ResolveParameterObj(func_graph, obj); @@ -171,12 +171,12 @@ bool ResolveObjectToNode(const FuncGraphPtr& func_graph, const py::object& obj, } // transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const AnfNodePtr& node, - const ValueNodePtr& value_node, AnfNodePtr* const transformed) { +bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, + const ValueNodePtr &value_node, AnfNodePtr *const transformed) { MS_EXCEPTION_IF_NULL(value_node); - const auto& value_vec = GetValue>(value_node->value()); + const auto &value_vec = GetValue>(value_node->value()); bool has_graph_in_list = false; - for (auto& elemv : value_vec) { + for (auto &elemv 
: value_vec) { MS_EXCEPTION_IF_NULL(elemv); if (elemv->isa()) { FuncGraphPtr new_fg = elemv->cast(); @@ -196,10 +196,10 @@ bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const Anf auto make_list_op = NewValueNode(prim::kPrimMakeTuple); list_vec.emplace_back(make_list_op); (void)std::transform(std::begin(value_vec), std::end(value_vec), std::back_inserter(list_vec), - [](const ValuePtr& value) { return NewValueNode(value); }); + [](const ValuePtr &value) { return NewValueNode(value); }); FuncGraphPtr cnode_graph = nullptr; auto users = manager->node_users()[node]; - for (auto& use : users) { + for (auto &use : users) { auto use_node = use.first; MS_EXCEPTION_IF_NULL(use_node); if (use_node->isa()) { @@ -220,8 +220,8 @@ bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const Anf } } // namespace -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& name_space, const SymbolPtr& symbol, - const AnfNodePtr& node) { +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { if (node->func_graph() == nullptr || manager == nullptr) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; } @@ -253,7 +253,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& } namespace { -opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib& irpass) { +opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { opt::OptPassGroupMap map({ {"resolve", { @@ -266,7 +266,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib& ir } } // namespace -bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBasePtr& res, bool use_profile) { +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) { if (func_graph == nullptr || res == nullptr) { MS_LOG(ERROR) << "func_graph or resource is null"; return false; @@ -282,7 +282,7 @@ bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBa return true; } -bool ResolveAll(const FuncGraphManagerPtr& manager) { +bool ResolveAll(const FuncGraphManagerPtr &manager) { if (manager == nullptr) { MS_LOG(ERROR) << "func graph manager is null"; return false; @@ -301,7 +301,7 @@ bool ResolveAll(const FuncGraphManagerPtr& manager) { res->set_manager(manager); auto roots = manager->roots(); - for (auto& fg : roots) { + for (auto &fg : roots) { bool ret = ResolveFuncGraph(fg, res, false); if (!ret) { MS_EXCEPTION_IF_NULL(fg); diff --git a/mindspore/ccsrc/pipeline/parse/resolve.h b/mindspore/ccsrc/pipeline/parse/resolve.h index ccc22c72dc3..acabfaf54b3 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/parse/resolve.h @@ -39,7 +39,7 @@ namespace parse { // NameSpace class for resolving python code. class NameSpace : public Named { public: - NameSpace(const std::string& module, const py::object& obj) : Named(module), module_(module), obj_(obj) {} + NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} ~NameSpace() override = default; MS_DECLARE_PARENT(NameSpace, Named); @@ -60,8 +60,8 @@ using NameSpacePtr = std::shared_ptr; // Symbol in NameSpace or Class which shall be resolved. 
class Symbol : public Named { public: - explicit Symbol(const std::string& symbol) : Named(symbol), symbol_(symbol) {} - explicit Symbol(const std::string& symbol, const std::string& name) : Named(name), symbol_(symbol) {} + explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} + explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} ~Symbol() override = default; MS_DECLARE_PARENT(Symbol, Named); @@ -79,7 +79,7 @@ using SymbolPtr = std::shared_ptr; // PyObjectWrapper class wrappers resolved python object for further processing. class PyObjectWrapper : public Named { public: - explicit PyObjectWrapper(const py::object& obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} ~PyObjectWrapper() override = default; MS_DECLARE_PARENT(PyObjectWrapper, Named); py::object obj() { return obj_; } @@ -92,7 +92,7 @@ class PyObjectWrapper : public Named { // ClassObject class wrappers dataclass class ClassObject : public PyObjectWrapper { public: - explicit ClassObject(const py::object& obj, const std::string name = "Python dataclass") + explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") : PyObjectWrapper(obj, name) {} ~ClassObject() override = default; MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); @@ -102,7 +102,7 @@ class ClassObject : public PyObjectWrapper { // ClassType class wrappers class name in python class ClassType : public PyObjectWrapper { public: - explicit ClassType(const py::object& obj, const std::string name = "Python class type") + explicit ClassType(const py::object &obj, const std::string name = "Python class type") : PyObjectWrapper(obj, name) {} ~ClassType() override = default; MS_DECLARE_PARENT(ClassType, PyObjectWrapper); @@ -112,7 +112,7 @@ class ClassType : public PyObjectWrapper { // SymbolResolver class for resolving symbol extracted from AnfNode. class SymbolResolver { public: - SymbolResolver(const NameSpacePtr& name_space, const SymbolPtr& symbol, const AnfNodePtr& node) + SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} ~SymbolResolver() = default; @@ -124,7 +124,7 @@ class SymbolResolver { SymbolPtr symbol() { return symbol_; } - py::object& result() { return result_; } + py::object &result() { return result_; } AnfNodePtr resolved_node() { return resolved_node_; } @@ -141,15 +141,15 @@ class SymbolResolver { }; using SymbolResolverPtr = std::shared_ptr; // Resolve symbol in namespace. -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& name_space, const SymbolPtr& symbol, - const AnfNodePtr& node); +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node); // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). -bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBasePtr& res, bool use_profile = true); +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); // Resolve all graphs in manager which is defined outside of pipeline::Resource. // Mainly used for test cases or resolve graphs which will not be managed by manager. 
-bool ResolveAll(const FuncGraphManagerPtr& manager); +bool ResolveAll(const FuncGraphManagerPtr &manager); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index b3eda4c37b1..6cdf6414433 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -48,7 +48,7 @@ using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; using mindspore::validator::Validate; -bool SimplifyDataStructuresPass(const ResourcePtr& res) { +bool SimplifyDataStructuresPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); @@ -57,7 +57,7 @@ bool SimplifyDataStructuresPass(const ResourcePtr& res) { abstract::AbstractBasePtrList args_spec; auto parameters = func_graph->parameters(); (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), - [](const AnfNodePtr& p) -> AbstractBasePtr { return p->abstract(); }); + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); res->set_func_graph(new_fg); res->set_args_spec(args_spec); @@ -65,7 +65,7 @@ bool SimplifyDataStructuresPass(const ResourcePtr& res) { } namespace { -OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ irpass.switch_simplify_, @@ -133,7 +133,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { return map_a; } -OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig({ irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, @@ -157,7 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib& irpass) { return map; } -OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}); OptPassGroupMap map({ {"control_group", control_group}, @@ -173,7 +173,7 @@ OptPassGroupMap GetInferenceOptPreparePhases() { return prepare_map; } -OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); OptPassGroupMap map({{"prepare_group", prepare_group}}); return map; @@ -181,7 +181,7 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { static std::unordered_map> g_pass_opts = {}; -void InitOpt(const ResourcePtr& res) { +void InitOpt(const ResourcePtr &res) { if (g_pass_opts.size() == 0) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); @@ -193,13 +193,13 @@ void InitOpt(const ResourcePtr& res) { } // namespace void ReclaimOptimizer() { - for (auto& opt : g_pass_opts) { + for (auto &opt : g_pass_opts) { opt.second = nullptr; } g_pass_opts.clear(); } -bool OptPassGroup(const ResourcePtr& res, const std::string& name) { +bool OptPassGroup(const ResourcePtr &res, const std::string &name) { if (res->func_graph() == nullptr) { MS_LOG(ERROR) << "Opt passes int 
error"; return false; @@ -216,12 +216,12 @@ bool OptPassGroup(const ResourcePtr& res, const std::string& name) { return true; } -bool OptPassAGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_a"); } -bool OptPassBGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_b"); } -bool ControlGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_control"); } -bool PrepareGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_prepare"); } +bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } +bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } +bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } -bool AddControlDependPass(const ResourcePtr& res) { +bool AddControlDependPass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -237,7 +237,7 @@ bool AddControlDependPass(const ResourcePtr& res) { return true; } -bool CconvPass(const ResourcePtr& res) { +bool CconvPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr new_fg = LiftingClone(func_graph); @@ -245,14 +245,14 @@ bool CconvPass(const ResourcePtr& res) { return true; } -bool ValidatePass(const ResourcePtr& res) { +bool ValidatePass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); Validate(func_graph); return true; } -bool InferenceOptPreparePass(const ResourcePtr& res) { +bool InferenceOptPreparePass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); abstract::AbstractBasePtrList args_spec = res->args_spec(); diff --git a/mindspore/ccsrc/pipeline/pass.h b/mindspore/ccsrc/pipeline/pass.h index 3731d7e524c..2636879d018 100644 --- a/mindspore/ccsrc/pipeline/pass.h +++ b/mindspore/ccsrc/pipeline/pass.h @@ -30,11 +30,11 @@ using PassItem = std::pair>; extern std::vector kGePasses; extern std::vector kVmPasses; -bool CconvPass(const ResourcePtr& res); -bool ValidatePass(const ResourcePtr& res); -bool ConvertPrepareAdapt(const ResourcePtr& res); -bool AddControlDependPass(const ResourcePtr& res); -bool InferenceOptPreparePass(const ResourcePtr& res); +bool CconvPass(const ResourcePtr &res); +bool ValidatePass(const ResourcePtr &res); +bool ConvertPrepareAdapt(const ResourcePtr &res); +bool AddControlDependPass(const ResourcePtr &res); +bool InferenceOptPreparePass(const ResourcePtr &res); void ReclaimOptimizer(); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index cd4fe28db93..5b5cae40445 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -67,7 +67,7 @@ std::unordered_map& defaults) { +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); abstract::AbstractBasePtrList args_spec; @@ -147,7 +147,7 @@ py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple i ExecutorPy::ExecutorPy() {} -ResourcePtr ExecutorPy::GetResource(const std::string& phase) { +ResourcePtr ExecutorPy::GetResource(const std::string &phase) { MS_LOG(DEBUG) << "Phase size:" << info_.size(); if (info_.count(phase) == 0) { return nullptr; @@ -155,21 +155,21 @@ ResourcePtr 
ExecutorPy::GetResource(const std::string& phase) { return info_[phase]->resource; } -FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string& phase) { +FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } return info_[phase]->func_graph; } -std::size_t ExecutorPy::ArgListSize(const std::string& phase) { +std::size_t ExecutorPy::ArgListSize(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } return info_[phase]->arg_list_size; } -compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string& phase) { +compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { ResourcePtr res = GetResource(phase); MS_EXCEPTION_IF_NULL(res); if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is()) { @@ -179,17 +179,17 @@ compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string& phase) { return nullptr; } -bool ExecutorPy::HasCompiled(const std::string& phase) const { +bool ExecutorPy::HasCompiled(const std::string &phase) const { if (info_.count(phase) == 0) { return false; } return true; } -py::bytes ExecutorPy::GetFuncGraphProto(const std::string& phase, const std::string& ir_type) { +py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { FuncGraphPtr fg_ptr = GetFuncGraph(phase); if (fg_ptr == nullptr) { - for (auto& item : info_) { + for (auto &item : info_) { MS_LOG(DEBUG) << "Phase key is: " << item.first; } MS_LOG(EXCEPTION) << "Can not find func graph " << phase; @@ -214,34 +214,34 @@ py::bytes ExecutorPy::GetFuncGraphProto(const std::string& phase, const std::str MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; } -py::dict ExecutorPy::GetParameterLayout(const std::string& phase) { +py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { MS_LOG(DEBUG) << "GetParameterLayout!"; std::string layout_graph = phase + kStepParallelGraph; auto graph = GetFuncGraph(layout_graph); return mindspore::parallel::GetParameterLayout(graph); } -py::dict ExecutorPy::GetCNodeStrategy(const std::string& phase) { +py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { MS_LOG(DEBUG) << "GetCNodeStrategy!"; std::string layout_graph = phase + kStepParallelGraph; auto graph = GetFuncGraph(layout_graph); return mindspore::parallel::GetCNodeStrategy(graph); } -py::dict ExecutorPy::GetAllreduceFusion(const std::string& phase) { +py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { MS_LOG(INFO) << "GetAllreduceFusion!"; auto graph = GetFuncGraph(phase); return mindspore::parallel::GetAllreduceFusion(graph); } -void ExecutorPy::DelNetRes(const std::string& id) { +void ExecutorPy::DelNetRes(const std::string &id) { #ifdef ENABLE_GE FinalizeGe(); #endif if (executor_ != nullptr) { bool flag = false; auto tmp_info = info_; - for (auto& item : tmp_info) { + for (auto &item : tmp_info) { if (item.first.find(id) != string::npos) { MS_LOG(INFO) << "Delete network res:" << item.first; (void)info_.erase(item.first); @@ -271,7 +271,7 @@ ExecutorPy::~ExecutorPy() { ConfigManager::GetInstance().ResetConfig(); } -void ExecutorPy::SaveCompiledGraph(const std::string& phase_s) { +void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { // save the graph to ExecutorPy FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -294,7 +294,7 @@ void 
ExecutorPy::SaveCompiledGraph(const std::string& phase_s) { MS_LOG(INFO) << "End save compiled func graph!"; } -bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string& phase_s) const { +bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { std::string phase_prefix = GetPhasePrefix(phase_s); if (use_vm && phase_prefix == "export") { @@ -313,7 +313,7 @@ void ExecutorPy::GetGeBackendPolicy() const { } } -bool ExecutorPy::CompileInner(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm) { +bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { MS_LOG(DEBUG) << "Start ExecutorPy compile!"; if ((!py::isinstance(phase))) { MS_LOG(ERROR) << "Arg phase must be string."; @@ -376,7 +376,7 @@ bool ExecutorPy::CompileInner(const py::object& obj, const py::tuple& args, cons return true; } -void ExecutorPy::ReleaseResource(const py::object& phase) { +void ExecutorPy::ReleaseResource(const py::object &phase) { ResourcePtr res = GetResource(py::cast(phase)); if (res != nullptr) { res->Clean(); @@ -385,18 +385,18 @@ void ExecutorPy::ReleaseResource(const py::object& phase) { ReclaimOptimizer(); } -static std::string PrintArgs(const py::tuple& args) { +static std::string PrintArgs(const py::tuple &args) { py::print(args); return ""; } -bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm) { +bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { bool ret_value = false; try { MS_LOG(DEBUG) << PrintArgs(args); ret_value = CompileInner(obj, args, phase, use_vm); - } catch (const py::error_already_set& ex) { + } catch (const py::error_already_set &ex) { // print function call stack info before release std::ostringstream oss; trace::TraceGraphInfer(); @@ -409,13 +409,13 @@ bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py: // re-throw this exception to Python interpreter to handle it throw(py::error_already_set(ex)); - } catch (const py::type_error& ex) { + } catch (const py::type_error &ex) { ReleaseResource(phase); throw py::type_error(ex); - } catch (const py::value_error& ex) { + } catch (const py::value_error &ex) { ReleaseResource(phase); throw py::value_error(ex); - } catch (const std::exception& ex) { + } catch (const std::exception &ex) { ReleaseResource(phase); // re-throw this exception to Python interpreter to handle it throw(std::runtime_error(ex.what())); @@ -432,7 +432,7 @@ bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py: // get MindSpore Intermediate Representation File std::string GetMsIrFile(void) { std::string file; - const char* path = getenv("MS_IR_FILE"); + const char *path = getenv("MS_IR_FILE"); if (path == nullptr) { return file; } @@ -446,7 +446,7 @@ std::string GetMsIrFile(void) { return file; } -void RunPipelineAction(const ActionItem& action, pipeline::ResourcePtr resource, bool* result) { +void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { MS_EXCEPTION_IF_NULL(resource); MS_EXCEPTION_IF_NULL(result); @@ -472,7 +472,7 @@ void RunPipelineAction(const ActionItem& action, pipeline::ResourcePtr resource, } auto manager = resource->manager(); MS_EXCEPTION_IF_NULL(manager); - for (auto& graph : graphs) { + for (auto &graph : graphs) { manager->AddFuncGraph(graph); } resource->set_func_graph(graphs[0]); @@ -491,9 +491,9 @@ void 
Pipeline::Run() { WITH(MsProfile::GetProfile())[&user_graph, this]() { int i = 0; - for (auto& action : actions_) { + for (auto &action : actions_) { #ifdef ENABLE_TIMELINE - DumpTime& dump_time = DumpTime::GetInstance(); + DumpTime &dump_time = DumpTime::GetInstance(); dump_time.Record(action.first, GetTime(), true); #endif bool result = true; @@ -575,7 +575,7 @@ void Pipeline::Run() { MS_LOG(INFO) << "End"; } -void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, VectorRef* arg_list) { +void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) { std::size_t size = args.size(); for (std::size_t i = 0; i < size; i++) { @@ -604,7 +604,7 @@ void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, V } } -py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { +py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { std::size_t size = args.size(); if (!py::isinstance(phase)) { MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; @@ -649,8 +649,8 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { return BaseRefToPyData(value); } -FuncGraphPtr ExecutorPy::BuildGraph(const py::dict& init_params, const std::string& phase, - const py::object& broadcast_params) { +FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params) { #if (ENABLE_GE || ENABLE_D) return BuildDFGraph(info_, init_params, phase, broadcast_params); #else @@ -658,15 +658,15 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict& init_params, const std::stri #endif } -void ExecutorPy::RunInitGraph(const py::dict& init_params, const std::string& phase) { +void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { #if ENABLE_GE RunGEInitGraph(init_params, phase); #endif } -bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase) { +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { std::string name = MsContext::GetInstance()->backend_policy(); if (name == kMsConvert || name == kMsVm) { return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes); @@ -682,16 +682,16 @@ bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t ba return false; } -bool InitExecDatasetVm(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes) { +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes) { MS_LOG(INFO) << "Start InitDataSet Entry"; std::vector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), [](int64_t item) { return static_cast(item); }); std::vector> int_shapes; (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), - [](const std::vector& item) { + [](const std::vector &item) { std::vector vector_item; (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), [](int64_t inner_item) { 
return static_cast(inner_item); }); @@ -774,7 +774,7 @@ void FinalizeHccl() { #endif } -void ExportGraph(const std::string& file_name, const std::string&, const std::string& phase) { +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { #if (ENABLE_GE || ENABLE_D) ExportDFGraph(file_name, phase); #endif diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index a0d7a191988..865c961ac13 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -43,7 +43,7 @@ namespace py = pybind11; class Pipeline { public: - Pipeline(const ResourcePtr& res, const std::vector& actions) : resource_(res), actions_(actions) {} + Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} ~Pipeline() = default; @@ -69,35 +69,35 @@ class ExecutorPy : public std::enable_shared_from_this { ~ExecutorPy(); - void SaveCompiledGraph(const std::string& phase_s); - bool CompileInner(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm); - bool Compile(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm); + void SaveCompiledGraph(const std::string &phase_s); + bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - void ProcessVmArg(const py::tuple& args, const std::string& phase, VectorRef* arg_list); + void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); // for pynative mode when use_vm is on - py::object Run(const py::tuple& args, const py::object& phase); - ResourcePtr GetResource(const std::string& phase); - FuncGraphPtr GetFuncGraph(const std::string& phase); - py::bytes GetFuncGraphProto(const std::string& phase, const std::string& type); - std::size_t ArgListSize(const std::string& phase); - compile::VmEvalFuncPtr GetVmEvalFunc(const std::string& phase); - bool HasCompiled(const std::string& phase) const; + py::object Run(const py::tuple &args, const py::object &phase); + ResourcePtr GetResource(const std::string &phase); + FuncGraphPtr GetFuncGraph(const std::string &phase); + py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); + std::size_t ArgListSize(const std::string &phase); + compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); + bool HasCompiled(const std::string &phase) const; - FuncGraphPtr BuildGraph(const py::dict& init_params, const std::string& phase, - const py::object& broadcast_params = {}); - void RunInitGraph(const py::dict& init_params, const std::string& phase); - py::dict GetParameterLayout(const std::string& phase); - py::dict GetCNodeStrategy(const std::string& phase); - py::dict GetAllreduceFusion(const std::string& phase); - void DelNetRes(const std::string& id); - void ReleaseResource(const py::object& phase); + FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params = {}); + void RunInitGraph(const py::dict &init_params, const std::string &phase); + py::dict GetParameterLayout(const std::string &phase); + py::dict GetCNodeStrategy(const std::string &phase); + py::dict GetAllreduceFusion(const std::string &phase); + void DelNetRes(const std::string &id); + void ReleaseResource(const py::object &phase); static void ClearRes(); private: ExecutorPy(); - void ConvertObjectToTensors(const py::dict& 
dict, std::map* tensors); - bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string& phase_s) const; + void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); + bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; void GetGeBackendPolicy() const; std::map info_; @@ -107,10 +107,10 @@ class ExecutorPy : public std::enable_shared_from_this { using ExecutorPyPtr = std::shared_ptr; // Generate a key for mapping function graph -py::tuple GenerateKey(const std::string& name, const std::unordered_map& defaults); +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); -bool InitDistribute(const std::map& options); +bool InitDistribute(const std::map &options); void ResetOpId(); void InitHccl(); @@ -121,17 +121,17 @@ void FinalizeGe(); void ClearResAtexit(); void ReleaseGeTsd(); -void ExportGraph(const std::string& file_name, const std::string&, const std::string& phase); +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); // init and exec dataset sub graph -bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase); +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); // Build and run dataset subgraph for ms backend -bool InitExecDatasetVm(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes); +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.cc b/mindspore/ccsrc/pipeline/pipeline_ge.cc index 6ce0ea53168..e3b10b73b0a 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/pipeline_ge.cc @@ -46,7 +46,7 @@ using mindspore::transform::MeTensorPtr; using mindspore::transform::Status; using mindspore::transform::TransformUtil; -void DoExecNonInputGraph(const std::string& phase) { +void DoExecNonInputGraph(const std::string &phase) { std::vector ge_tensors; std::vector ge_outputs; transform::RunOptions run_options; @@ -68,7 +68,7 @@ void DoExecNonInputGraph(const std::string& phase) { } } -void SetGeOption(const std::map& options) { +void SetGeOption(const std::map &options) { ConfigManager::GetInstance().set_ge_initialize_options(options); } @@ -108,11 +108,11 @@ Status CreateSessionAndGraphRunner(bool is_training = true) { return Status::SUCCESS; } -bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase) { +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { std::vector ge_types; - (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr& i) -> int64_t { + (void)std::transform(types.begin(), types.end(), 
std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { return transform::TransformUtil::ConvertDataType(i->type_id()); }); @@ -145,7 +145,7 @@ bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batc return true; } -void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) { +void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { for (auto item : dict) { if ((!py::isinstance(item.first))) { MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; @@ -156,11 +156,11 @@ void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) if (py::isinstance(item.second.attr("default_input"))) { // convert float to tensor with shape([1]) tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); - *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); + *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); } else if (py::isinstance(item.second.attr("default_input"))) { // convert int to tensor with shape([1]) tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); - *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); + *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { // cast tensor tensor = py::cast>(item.second.attr("default_input")); @@ -173,8 +173,8 @@ void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) } } -bool AddDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params) { +bool AddDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { FuncGraphPtr anf_graph = info.at(phase)->func_graph; DfGraphConvertor convertor(anf_graph); @@ -237,8 +237,8 @@ bool AddDFGraph(const std::map& info, const py::di return true; } -FuncGraphPtr BuildDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params) { +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { if (info.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } @@ -268,13 +268,13 @@ FuncGraphPtr BuildDFGraph(const std::map& info, co return anf_graph; } -void RunGEInitGraph(const py::dict& init_params, const std::string& phase) { +void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { MS_LOG(DEBUG) << "ExecInitGraph start."; TensorOrderMap inputs_with_name{}; ConvertObjectToTensors(init_params, &inputs_with_name); std::vector inputs; (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), - [](const std::pair& item) { return item.second; }); + [](const std::pair &item) { return item.second; }); std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); if (ge_tensors.size() != inputs.size()) { @@ -317,7 +317,7 @@ void RunGEInitGraph(const py::dict& init_params, const std::string& phase) { } } -py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) { +py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { MS_EXCEPTION_IF_NULL(cnode_data); if (*count >= data.size()) { 
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() @@ -350,7 +350,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::t return std::move(tp); } -py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) { +py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { MS_EXCEPTION_IF_NULL(output_node); if (output_node->isa()) { @@ -387,8 +387,8 @@ py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, return ExtractGeneralCnodeRet(output_c->abstract(), data, count); } -std::shared_ptr DoExecGraph(const FuncGraphPtr& graph, const std::vector& inputs, - const std::string& phase) { +std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, + const std::string &phase) { std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); if (ge_tensors.size() != inputs.size()) { MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; @@ -438,8 +438,8 @@ std::shared_ptr DoExecGraph(const FuncGraphPtr& graph, const std::ve return ret; } -void ProcessGeArg(const std::map& info, const py::tuple& args, const std::string& phase, - std::vector* inputs) { +void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, + std::vector *inputs) { // check the arg and use the ExecutorPy args std::size_t size = args.size(); @@ -470,8 +470,8 @@ void ProcessGeArg(const std::map& info, const py:: } } -py::object ExecDFGraph(const std::map& info, const py::tuple& args, - const std::string& phase) { +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase) { std::string phase_prefix = GetPhasePrefix(phase); if (phase_prefix == "save") { @@ -514,7 +514,7 @@ py::object ExecDFGraph(const std::map& info, const MS_LOG(EXCEPTION) << "Exec graph failed"; } } -void ExportDFGraph(const std::string& file_name, const std::string& phase) { +void ExportDFGraph(const std::string &file_name, const std::string &phase) { MS_LOG(DEBUG) << "ExportGraph Begin"; transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase); if (wrap_ptr == nullptr) { diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.h b/mindspore/ccsrc/pipeline/pipeline_ge.h index c3779fd9820..9dc15246822 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.h +++ b/mindspore/ccsrc/pipeline/pipeline_ge.h @@ -34,22 +34,22 @@ namespace pipeline { namespace py = pybind11; -void SetGeOption(const std::map& options); +void SetGeOption(const std::map &options); -void RunGEInitGraph(const py::dict& init_params, const std::string& phase); +void RunGEInitGraph(const py::dict &init_params, const std::string &phase); -py::object ExecDFGraph(const std::map& info, const py::tuple& args, - const std::string& phase = "train"); +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase = "train"); -FuncGraphPtr BuildDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params = {}); +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params = {}); // init and exec dataset sub graph for GE backend -bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase); +bool 
InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); -void ExportDFGraph(const std::string& file_name, const std::string& phase); +void ExportDFGraph(const std::string &file_name, const std::string &phase); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc index 7937c3e55f3..0b7401345ac 100644 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc +++ b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc @@ -24,9 +24,9 @@ namespace mindspore { namespace pipeline { -void TryToDoReplace(FuncGraphManager* const manager, const AnfNodePtr& node, HashCache* const hash_cache, - HashValue* const hash_value) { - const auto& to_check_value = GetValueNode(node); +void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, + HashValue *const hash_value) { + const auto &to_check_value = GetValueNode(node); MS_EXCEPTION_IF_NULL(to_check_value); // Calculate hash value. @@ -46,14 +46,14 @@ void TryToDoReplace(FuncGraphManager* const manager, const AnfNodePtr& node, Has return; } - auto& bucket = bucket_iter->second; + auto &bucket = bucket_iter->second; // Check if need to replace node with value node already met. - for (const auto& v : bucket) { + for (const auto &v : bucket) { // Already met and cached. if (v == node) { return; } - const auto& existed_value = GetValueNode(v); + const auto &existed_value = GetValueNode(v); MS_EXCEPTION_IF_NULL(existed_value); auto equal = [&]() -> bool { if (existed_value->isa() && to_check_value->isa()) { diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/remove_value_node_dup.h index 8fbb3f27557..8f670c7dcfd 100644 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.h +++ b/mindspore/ccsrc/pipeline/remove_value_node_dup.h @@ -27,7 +27,7 @@ namespace pipeline { using HashCache = std::unordered_map>; using HashValue = std::unordered_map; -void TryToDoReplace(FuncGraphManager* manager, const AnfNodePtr& node, HashCache* hash_cache, HashValue* hash_value); +void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc index 18695518bef..50ccef2f44c 100644 --- a/mindspore/ccsrc/pipeline/resource.cc +++ b/mindspore/ccsrc/pipeline/resource.cc @@ -32,7 +32,7 @@ namespace mindspore { // namespace to support opmap definition namespace pipeline { -MethodMap& GetMethodMap() { +MethodMap &GetMethodMap() { static MethodMap method_map = {{kObjectTypeString, { {"__bool__", std::string("str_bool")} // C.str_bool @@ -178,7 +178,7 @@ MethodMap& GetMethodMap() { return method_map; } -Resource::Resource(const py::object& obj) +Resource::Resource(const py::object &obj) : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), input_(obj), is_cleaned_(false) {} @@ -197,7 +197,7 @@ Resource::~Resource() { if (!is_cleaned_) { try { Clean(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); } catch (...) 
{ MS_LOG(ERROR) << "Exception when cleaning resource."; @@ -205,9 +205,9 @@ Resource::~Resource() { } } -bool Resource::IsTypeInMethodMap(const TypeId& type) { +bool Resource::IsTypeInMethodMap(const TypeId &type) { TypeId type_id = NormalizeTypeId(type); - const MethodMap& method_map = GetMethodMap(); + const MethodMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter != method_map.end()) { return true; @@ -215,9 +215,9 @@ bool Resource::IsTypeInMethodMap(const TypeId& type) { return false; } -Any Resource::GetMethodPtr(const TypeId& type, const std::string& name) { +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { TypeId type_id = NormalizeTypeId(type); - const MethodMap& method_map = GetMethodMap(); + const MethodMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter == method_map.end()) { MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; diff --git a/mindspore/ccsrc/pipeline/resource.h b/mindspore/ccsrc/pipeline/resource.h index 15ab70db141..0c1348fd943 100644 --- a/mindspore/ccsrc/pipeline/resource.h +++ b/mindspore/ccsrc/pipeline/resource.h @@ -46,7 +46,7 @@ class InferenceResource; using MethodMap = std::unordered_map>; -MethodMap& GetMethodMap(); +MethodMap &GetMethodMap(); class ResourceBase { public: @@ -56,20 +56,20 @@ class ResourceBase { FuncGraphManagerPtr manager() { return manager_; } // set a manager defined outside which will not manage the graphs. - void set_manager(const FuncGraphManagerPtr& manager) { manager_ = manager; } + void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; } - std::unordered_map& results() { return results_; } + std::unordered_map &results() { return results_; } - void SetResult(const std::string& key, const Any& value) { results_[key] = value; } + void SetResult(const std::string &key, const Any &value) { results_[key] = value; } - Any GetResult(const std::string& key) { + Any GetResult(const std::string &key) { if (results_.count(key) == 0) { MS_LOG(EXCEPTION) << "this key is not in resource list:" << key; } return results_[key]; } - bool HasResult(const std::string& key) const { return results_.count(key) != 0; } + bool HasResult(const std::string &key) const { return results_.count(key) != 0; } std::unordered_map results_; @@ -81,23 +81,23 @@ using ResourceBasePtr = std::shared_ptr; class Resource : public ResourceBase { public: - explicit Resource(const py::object& obj = py::none()); + explicit Resource(const py::object &obj = py::none()); ~Resource() override; abstract::AnalysisEnginePtr engine() { return engine_; } - static bool IsTypeInMethodMap(const TypeId& type); + static bool IsTypeInMethodMap(const TypeId &type); - static Any GetMethodPtr(const TypeId& type, const std::string& name); + static Any GetMethodPtr(const TypeId &type, const std::string &name); - const py::object& input() const { return input_; } + const py::object &input() const { return input_; } FuncGraphPtr func_graph() const { return func_graph_; } - void set_func_graph(const FuncGraphPtr& func_graph) { func_graph_ = func_graph; } + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } - const abstract::AbstractBasePtrList& args_spec() const { return args_spec_; } - void set_args_spec(const abstract::AbstractBasePtrList& args_spec) { args_spec_ = args_spec; } + const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; } + void set_args_spec(const abstract::AbstractBasePtrList 
&args_spec) { args_spec_ = args_spec; } // Reclaim resource and clear the cache. // ExecutorPy::Compile() can be called multiple times, so cache diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc index 15aa71ba1e4..183ec772fff 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc @@ -26,31 +26,31 @@ namespace mindspore { namespace abstract { // used for print BaseShape content -std::ostream& operator<<(std::ostream& os, const BaseShape& bs) { +std::ostream &operator<<(std::ostream &os, const BaseShape &bs) { os << bs.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr bs) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr bs) { MS_EXCEPTION_IF_NULL(bs); os << bs->ToString(); return os; } -bool BaseShape::operator==(const BaseShape& other) const { +bool BaseShape::operator==(const BaseShape &other) const { if (tid() != other.tid()) { return false; } return true; } -bool BaseShape::operator!=(const BaseShape& other) const { return !(*this == other); } +bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); } std::string Shape::ToString() const { std::ostringstream buffer; bool f_begin = true; buffer << "("; - for (auto& x : shape_) { + for (auto &x : shape_) { if (!f_begin) { buffer << ", "; } else { @@ -72,11 +72,11 @@ std::string Shape::DumpText() const { return buffer.str(); } -bool Shape::operator==(const BaseShape& other) const { +bool Shape::operator==(const BaseShape &other) const { if (tid() != other.tid()) { return false; } - return shape_ == static_cast(other).shape_; + return shape_ == static_cast(other).shape_; } const int Shape::SHP_ANY; @@ -111,11 +111,11 @@ BaseShapePtrList SequeueShape::ElementsClone() const { } template -bool SequeueShape::SequeueEqual(const BaseShape& other) const { +bool SequeueShape::SequeueEqual(const BaseShape &other) const { if (tid() != other.tid()) { return false; } - auto other_shapes = static_cast(other).p_shapes_; + auto other_shapes = static_cast(other).p_shapes_; if (other_shapes.size() != p_shapes_.size()) { return false; } @@ -126,8 +126,8 @@ bool SequeueShape::SequeueEqual(const BaseShape& other) const { } return true; } -template bool SequeueShape::SequeueEqual(const BaseShape&) const; -template bool SequeueShape::SequeueEqual(const BaseShape&) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; const std::shared_ptr kNoShape = std::make_shared(); } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.h b/mindspore/ccsrc/pipeline/static_analysis/dshape.h index 6debe061c88..3e850e309b4 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.h +++ b/mindspore/ccsrc/pipeline/static_analysis/dshape.h @@ -41,8 +41,8 @@ class BaseShape : public Base { ~BaseShape() override = default; MS_DECLARE_PARENT(BaseShape, Base) - virtual bool operator==(const BaseShape& other) const; - bool operator!=(const BaseShape& other) const; + virtual bool operator==(const BaseShape &other) const; + bool operator!=(const BaseShape &other) const; std::size_t hash() const override { return tid(); } // return a deep copy @@ -62,16 +62,16 @@ class Shape : public BaseShape { public: static const int SHP_ANY = -1; Shape() : shape_() {} - Shape(const std::initializer_list& list) : shape_(list) {} - explicit Shape(const std::vector& 
list) : shape_(list) {} + Shape(const std::initializer_list &list) : shape_(list) {} + explicit Shape(const std::vector &list) : shape_(list) {} ~Shape() override = default; MS_DECLARE_PARENT(Shape, BaseShape) std::string ToString() const override; std::string DumpText() const override; - bool operator==(const BaseShape& other) const override; + bool operator==(const BaseShape &other) const override; BaseShapePtr Clone() const override { return std::make_shared(shape_); } void Broaden() override; - std::vector& shape() { return shape_; } + std::vector &shape() { return shape_; } std::vector shape_; // use SHP_ANY to implement the any shape in python }; @@ -81,7 +81,7 @@ using ShapePtrList = std::vector; class SequeueShape : public BaseShape { public: SequeueShape() : p_shapes_() {} - explicit SequeueShape(const BaseShapePtrList& shapes) : p_shapes_(shapes) {} + explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} ~SequeueShape() override = default; MS_DECLARE_PARENT(SequeueShape, BaseShape) @@ -89,9 +89,9 @@ class SequeueShape : public BaseShape { BaseShapePtrList ElementsClone() const; template - bool SequeueEqual(const BaseShape& other) const; + bool SequeueEqual(const BaseShape &other) const; - const BaseShapePtrList& shape() const { return p_shapes_; } + const BaseShapePtrList &shape() const { return p_shapes_; } size_t size() const { return p_shapes_.size(); } const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } @@ -103,7 +103,7 @@ using SequeueShapePtr = std::shared_ptr; class TupleShape : public SequeueShape { public: TupleShape() : SequeueShape() {} - explicit TupleShape(const BaseShapePtrList& shapes) : SequeueShape(shapes) {} + explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} ~TupleShape() override = default; MS_DECLARE_PARENT(TupleShape, SequeueShape) @@ -111,14 +111,14 @@ class TupleShape : public SequeueShape { BaseShapePtr Clone() const override { return std::make_shared(ElementsClone()); } - bool operator==(const BaseShape& other) const override { return SequeueEqual(other); } + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } }; using TupleShapePtr = std::shared_ptr; class ListShape : public SequeueShape { public: ListShape() : SequeueShape() {} - explicit ListShape(const BaseShapePtrList& shapes) : SequeueShape(shapes) {} + explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} ~ListShape() override = default; MS_DECLARE_PARENT(ListShape, SequeueShape) @@ -126,7 +126,7 @@ class ListShape : public SequeueShape { BaseShapePtr Clone() const override { return std::make_shared(SequeueShape::ElementsClone()); } - bool operator==(const BaseShape& other) const override { return SequeueEqual(other); } + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } }; using ListShapePtr = std::shared_ptr; } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/validator.cc b/mindspore/ccsrc/pipeline/validator.cc index 0fe32188131..73a54bb1807 100644 --- a/mindspore/ccsrc/pipeline/validator.cc +++ b/mindspore/ccsrc/pipeline/validator.cc @@ -39,7 +39,7 @@ using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractType; -void ValidateOperation(const AnfNodePtr& node) { +void ValidateOperation(const AnfNodePtr &node) { if (!IsValueNode(node)) { return; } @@ -60,7 +60,7 @@ void ValidateOperation(const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "Illegal 
primitive: " << prim->name(); } -void ValidateAbstract(const AnfNodePtr& node) { +void ValidateAbstract(const AnfNodePtr &node) { if (node == nullptr) { MS_LOG(WARNING) << "Node to validate is invalid"; return; @@ -105,11 +105,11 @@ void ValidateAbstract(const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); } -void Validate(const FuncGraphPtr& fg) { +void Validate(const FuncGraphPtr &fg) { FuncGraphManagerPtr mgr = Manage(fg, false); MS_EXCEPTION_IF_NULL(mgr); - AnfNodeSet& all_nodes = mgr->all_nodes(); - for (const auto& anf_node : all_nodes) { + AnfNodeSet &all_nodes = mgr->all_nodes(); + for (const auto &anf_node : all_nodes) { ValidateOperation(anf_node); ValidateAbstract(anf_node); } diff --git a/mindspore/ccsrc/pipeline/validator.h b/mindspore/ccsrc/pipeline/validator.h index 9944078e6c7..61f74703496 100644 --- a/mindspore/ccsrc/pipeline/validator.h +++ b/mindspore/ccsrc/pipeline/validator.h @@ -29,9 +29,9 @@ namespace mindspore { namespace validator { -void Validate(const FuncGraphPtr& func_graph); -void ValidateAbstract(const AnfNodePtr& node); -void ValidateOperation(const AnfNodePtr& node); +void Validate(const FuncGraphPtr &func_graph); +void ValidateAbstract(const AnfNodePtr &node); +void ValidateOperation(const AnfNodePtr &node); } // namespace validator } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc index f0077ef6cdf..c9ef381f16c 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc @@ -121,7 +121,7 @@ bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) co return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; } -void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr& mem_buf) { +void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { MS_EXCEPTION_IF_NULL(mem_buf); auto mem_block = FindMemBlock(mem_buf->device_addr_); MS_EXCEPTION_IF_NULL(mem_block); @@ -160,7 +160,7 @@ void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { CombineMemBuf(mem_block, device_addr); } -void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr& mem_block, const DeviceMemPtr device_addr) { +void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { MS_EXCEPTION_IF_NULL(mem_block); MS_EXCEPTION_IF_NULL(device_addr); auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h index dcf735814ca..c6287560701 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h @@ -61,7 +61,7 @@ class DynamicMemBlock { DynamicMemBlock() = default; DynamicMemBlock(DeviceMemPtr addr_base, size_t size) : device_addr_base_(addr_base), mem_block_size_(size) {} ~DynamicMemBlock() { block_all_mem_buf_map_.clear(); } - const DeviceMemPtr& device_addr() const { return device_addr_base_; } + const DeviceMemPtr &device_addr() const { return device_addr_base_; } size_t size() const { return mem_block_size_; } // The map of all memory buf in this memory block by device address. 
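The mem_dynamic_allocator hunks around this point all touch the divide-and-combine logic of the best-fit pool (IsDivide, DivideMemBuf, CombineMemBuf). A minimal sketch of that idea, using illustrative names, sizes, and containers rather than the real MindSpore types:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>

constexpr size_t kAlignSize = 512;  // stands in for DYNAMIC_MEM_ALIGN_SIZE

// Toy best-fit pool: idle buffers keyed by offset; a buffer is divided only when
// the remainder is at least one alignment unit, and a freed buffer is combined
// with its idle right-hand neighbour (left-merge omitted for brevity).
class TinyBestFitPool {
 public:
  explicit TinyBestFitPool(size_t total) { free_[0] = total; }

  size_t Alloc(size_t size) {
    size = (size + kAlignSize - 1) / kAlignSize * kAlignSize;
    auto best = free_.end();
    for (auto it = free_.begin(); it != free_.end(); ++it) {
      if (it->second >= size && (best == free_.end() || it->second < best->second)) best = it;
    }
    if (best == free_.end()) return SIZE_MAX;  // no idle buffer is large enough
    size_t addr = best->first;
    size_t remain = best->second - size;
    free_.erase(best);
    if (remain >= kAlignSize) {
      free_[addr + size] = remain;  // divide: keep the tail as a new idle buffer
    } else {
      size += remain;               // too small to divide, hand out the whole buffer
    }
    used_[addr] = size;
    return addr;
  }

  void Free(size_t addr) {
    auto it = used_.find(addr);
    if (it == used_.end()) return;
    size_t size = it->second;
    used_.erase(it);
    auto right = free_.find(addr + size);
    if (right != free_.end()) {  // combine with the idle neighbour to limit fragmentation
      size += right->second;
      free_.erase(right);
    }
    free_[addr] = size;
  }

 private:
  std::map<size_t, size_t> free_;  // offset -> size of idle buffers
  std::map<size_t, size_t> used_;  // offset -> size of allocated buffers
};

int main() {
  TinyBestFitPool pool(4096);
  size_t a = pool.Alloc(1000);  // rounds up to 1024 and splits the 4096 block
  size_t b = pool.Alloc(1024);
  pool.Free(b);                 // merges with the idle tail of the pool
  pool.Free(a);
  std::cout << "a=" << a << " b=" << b << std::endl;
  return 0;
}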
DeviceAddrMapMemBuf block_all_mem_buf_map_; @@ -92,8 +92,8 @@ class DynamicMemPoolBestFit { size_t used_mem_peak_statistics() const { return used_mem_peak_statistics_; } // The related interface of device memory real operation, needs override by device type. - virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) = 0; - virtual bool FreeDeviceMem(const DeviceMemPtr& addr) = 0; + virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) = 0; + virtual bool FreeDeviceMem(const DeviceMemPtr &addr) = 0; virtual size_t free_mem_size() = 0; virtual size_t total_mem_size() = 0; @@ -113,14 +113,14 @@ class DynamicMemPoolBestFit { // Judge whether need divide the memory buf by alloc size and memory buf size. bool IsDivide(size_t tensor_size, size_t mem_buf_size) const; // Divide the memory buf by alloc size. - void DivideMemBuf(size_t size, const DynamicMemBufPtr& mem_buf); + void DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf); // Find the memory block by deivce address. DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr device_addr); // The Comparator of memory block by device address, because memory blocks are arranged in order by device address. static bool CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block); // Combine the memory buf when memory free, to avoid the memory fragmentation. - void CombineMemBuf(const DynamicMemBlockPtr& mem_block, const DeviceMemPtr device_addr); + void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr); // Erase the idle memory buf by size and device address when idle memory buf is combined. void EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr); diff --git a/mindspore/ccsrc/predict/generator/ir/ir_model.h b/mindspore/ccsrc/predict/generator/ir/ir_model.h index bf1c057b5ff..82bd2aad3f3 100644 --- a/mindspore/ccsrc/predict/generator/ir/ir_model.h +++ b/mindspore/ccsrc/predict/generator/ir/ir_model.h @@ -23,7 +23,7 @@ namespace mindspore { namespace generator { class IRModel { public: - void SetIrTaskInfos(const std::vector& ir_tasks); + void SetIrTaskInfos(const std::vector &ir_tasks); IRModel() = default; ~IRModel(); diff --git a/mindspore/ccsrc/pybind_api/api_register.h b/mindspore/ccsrc/pybind_api/api_register.h index 2c1b622f31f..8bab751267d 100644 --- a/mindspore/ccsrc/pybind_api/api_register.h +++ b/mindspore/ccsrc/pybind_api/api_register.h @@ -29,19 +29,19 @@ namespace py = pybind11; namespace mindspore { -using PybindDefineFunc = std::function; +using PybindDefineFunc = std::function; class PybindDefineRegister { public: - static void Register(const std::string& name, const PybindDefineFunc& fn) { + static void Register(const std::string &name, const PybindDefineFunc &fn) { return GetSingleton().RegisterFn(name, fn); } - PybindDefineRegister(const PybindDefineRegister&) = delete; + PybindDefineRegister(const PybindDefineRegister &) = delete; - PybindDefineRegister& operator=(const PybindDefineRegister&) = delete; + PybindDefineRegister &operator=(const PybindDefineRegister &) = delete; - static std::map& AllFuncs() { return GetSingleton().fns_; } + static std::map &AllFuncs() { return GetSingleton().fns_; } std::map fns_; @@ -50,14 +50,14 @@ class PybindDefineRegister { virtual ~PybindDefineRegister() = default; - static PybindDefineRegister& GetSingleton(); + static PybindDefineRegister &GetSingleton(); - void RegisterFn(const std::string& name, const PybindDefineFunc& fn) { fns_[name] = fn; } + void RegisterFn(const std::string &name, const PybindDefineFunc 
&fn) { fns_[name] = fn; } }; class PybindDefineRegisterer { public: - PybindDefineRegisterer(const std::string& name, const PybindDefineFunc& fn) { + PybindDefineRegisterer(const std::string &name, const PybindDefineFunc &fn) { PybindDefineRegister::Register(name, fn); } ~PybindDefineRegisterer() = default; diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index d8675adc9ce..37ff000b045 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -58,7 +58,7 @@ struct OpExecInfo { py::dict op_attrs; }; using OpExecInfoPtr = std::shared_ptr; -OpExecInfoPtr GenerateOpExecInfo(const py::args& args); +OpExecInfoPtr GenerateOpExecInfo(const py::args &args); const std::set ignore_infer_prim = {"partial", "make_ref"}; diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index 6a1ddf6a7e4..0d18dfb5770 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -43,7 +43,7 @@ const std::unordered_set vm_operators = {"partial", "depend", "make namespace mindspore { namespace pynative { -inline ValuePtr PyAttrValue(const py::object& obj) { +inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { @@ -52,11 +52,11 @@ inline ValuePtr PyAttrValue(const py::object& obj) { return converted_ret; } -py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) { auto signature = prim->signatures(); std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature& sig) { return sig.dtype; }); + [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { return py_args; @@ -103,7 +103,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { return py_inputs; } -void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) { +void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) { size_t size = py_args.size(); AbstractBasePtrList args_spec_list; for (size_t i = 0; i < size; i++) { @@ -118,7 +118,7 @@ void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecI op_exec_info->abstract = infer_res; } -OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { +OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { MS_LOG(ERROR) << "Four args are needed by RunOp"; return nullptr; @@ -147,7 +147,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { return op_exec_info; } -std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { +std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(op_exec_info); std::string graph_info; MS_EXCEPTION_IF_NULL(op_exec_info->abstract); @@ -167,7 +167,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { return graph_info; } -py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInVM start"; 
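The api_register.h hunk above reworks PybindDefineRegister and PybindDefineRegisterer; a compact sketch of that static-registration pattern, assuming simplified stand-in names and a no-argument callback rather than the real pybind11 module hook:

#include <functional>
#include <iostream>
#include <map>
#include <string>

using DefineFunc = std::function<void()>;

class Registry {
 public:
  static Registry &GetSingleton() {
    static Registry inst;
    return inst;
  }
  static void Register(const std::string &name, const DefineFunc &fn) {
    GetSingleton().fns_[name] = fn;
  }
  static std::map<std::string, DefineFunc> &AllFuncs() { return GetSingleton().fns_; }

 private:
  Registry() = default;
  std::map<std::string, DefineFunc> fns_;
};

// Constructing one of these at namespace scope registers the callback at start-up.
struct Registerer {
  Registerer(const std::string &name, const DefineFunc &fn) { Registry::Register(name, fn); }
};

static Registerer g_demo("demo", [] { std::cout << "define demo bindings" << std::endl; });

int main() {
  for (auto &kv : Registry::AllFuncs()) kv.second();  // run every registered define function
  return 0;
}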
MS_EXCEPTION_IF_NULL(status); @@ -188,7 +188,7 @@ py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat return std::move(result); } -py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_EXCEPTION_IF_NULL(op_exec_info); MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; auto ms_context = MsContext::GetInstance(); @@ -212,7 +212,7 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat } py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info, - PynativeStatusCode* const status) { + PynativeStatusCode *const status) { MS_EXCEPTION_IF_NULL(status); py::object result; switch (backend_policy) { @@ -248,7 +248,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn return result; } -py::tuple RunOp(const py::args& args) { +py::tuple RunOp(const py::args &args) { py::object result; // returns a null py::tuple on error py::tuple err_ret(0); diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h index 17b5610bfd0..c64c6b4b251 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pynative/pynative_execute.h @@ -33,9 +33,9 @@ namespace pynative { namespace py = pybind11; -py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status); +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); -py::tuple RunOp(const py::args& args); +py::tuple RunOp(const py::args &args); } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pynative/pynative_execute_ge.cc index 180b0006ffb..0bf2a391f94 100644 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.cc +++ b/mindspore/ccsrc/pynative/pynative_execute_ge.cc @@ -43,7 +43,7 @@ using transform::GraphRunner; using transform::GraphRunnerOptions; using transform::OperatorPtr; static std::shared_ptr session = nullptr; -inline ValuePtr PyAttrValue(const py::object& obj) { +inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { @@ -52,7 +52,7 @@ inline ValuePtr PyAttrValue(const py::object& obj) { return converted_ret; } -MeTensorPtr ConvertPyObjToTensor(const py::object& obj) { +MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { MeTensorPtr me_tensor_ptr = nullptr; if (py::isinstance(obj)) { me_tensor_ptr = py::cast(obj); @@ -72,8 +72,8 @@ MeTensorPtr ConvertPyObjToTensor(const py::object& obj) { return me_tensor_ptr; } -bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const OperatorPtr& op, std::vector* graph_input_nodes) { +bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const OperatorPtr &op, std::vector *graph_input_nodes) { MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(graph_input_nodes); auto op_inputs = op_exec_info->op_inputs; @@ -103,7 +103,7 @@ bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vec auto pointer_cast_const_op = std::static_pointer_cast(const_op); MS_EXCEPTION_IF_NULL(pointer_cast_const_op); (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); - auto& input_map = adapter->getInputMap(); + auto &input_map = 
adapter->getInputMap(); if (input_map.find(op_input_idx) == input_map.end()) { continue; } @@ -116,8 +116,8 @@ bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vec return true; } -bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const std::unordered_map& attrs, const GeGraphPtr& graph) { +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph) { MS_EXCEPTION_IF_NULL(op_exec_info); std::string op_name = op_exec_info->op_name; auto op_inputs = op_exec_info->op_inputs; @@ -145,8 +145,8 @@ bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vectorsetAttr(op, attr.first, attr.second); } // set input attributes - auto& input_attr_map = adapter->getInputAttrMap(); - for (auto& it : input_attr_map) { + auto &input_attr_map = adapter->getInputAttrMap(); + for (auto &it : input_attr_map) { if (op_inputs.size() < it.first) { continue; } @@ -165,7 +165,7 @@ bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector* const inputs) { +void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { MS_EXCEPTION_IF_NULL(inputs); MS_EXCEPTION_IF_NULL(op_exec_info); auto op_inputs = op_exec_info->op_inputs; @@ -185,12 +185,12 @@ void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector* con } } -PynativeStatusCode ConvertAttributes(const OpExecInfoPtr& op_exec_info, const std::vector& inputs) { +PynativeStatusCode ConvertAttributes(const OpExecInfoPtr &op_exec_info, const std::vector &inputs) { MS_EXCEPTION_IF_NULL(op_exec_info); auto op_attrs = op_exec_info->op_attrs; std::unordered_map attrs{}; - for (auto& item : op_attrs) { + for (auto &item : op_attrs) { if (!py::isinstance(item.first)) { MS_LOG(ERROR) << "Type error in py dict convert"; return PYNATIVE_OP_ATTRS_ERR; @@ -218,8 +218,8 @@ PynativeStatusCode ConvertAttributes(const OpExecInfoPtr& op_exec_info, const st return PYNATIVE_SUCCESS; } -std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, - const std::vector& ge_tensors) { +std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, + const std::vector &ge_tensors) { std::vector outputs; AbstractBasePtr abs_base = op_exec_info->abstract; std::vector> shapes; @@ -242,7 +242,7 @@ std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); return outputs; } - for (auto& it : ge_tensors) { + for (auto &it : ge_tensors) { auto tensor = transform::TransformUtil::ConvertGeTensor(it); if (tensor != nullptr) { outputs.emplace_back(tensor); @@ -251,7 +251,7 @@ std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, return outputs; } -py::object RunOpInGE(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInGe start"; MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(status); diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pynative/pynative_execute_ge.h index af0efec3e33..2dca3df0187 100644 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.h +++ b/mindspore/ccsrc/pynative/pynative_execute_ge.h @@ -36,10 +36,10 @@ using GeGraphPtr = std::shared_ptr; namespace mindspore { namespace pynative { -bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const std::unordered_map& attrs, const 
GeGraphPtr& graph); +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph); -py::object RunOpInGE(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status); +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/transform/convert.h b/mindspore/ccsrc/transform/convert.h index 556db5aceee..5596e20f196 100644 --- a/mindspore/ccsrc/transform/convert.h +++ b/mindspore/ccsrc/transform/convert.h @@ -51,16 +51,16 @@ class OpAdapterDesc { public: OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} - OpAdapterDesc(const OpAdapterPtr& train, const OpAdapterPtr& infer) : train_(train), infer_(infer) {} + OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} - explicit OpAdapterDesc(const OpAdapterPtr& common) : train_(common), infer_(common) {} + explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} - OpAdapterDesc(const OpAdapterDesc& desc) { + OpAdapterDesc(const OpAdapterDesc &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; } - OpAdapterDesc(OpAdapterDesc&& desc) { + OpAdapterDesc(OpAdapterDesc &&desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; desc.train_ = nullptr; @@ -71,7 +71,7 @@ class OpAdapterDesc { OpAdapterPtr Get(bool train) const { return train ? train_ : infer_; } - OpAdapterDesc& operator=(const OpAdapterDesc& desc) { + OpAdapterDesc &operator=(const OpAdapterDesc &desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; @@ -79,7 +79,7 @@ class OpAdapterDesc { return *this; } - OpAdapterDesc& operator=(OpAdapterDesc&& desc) { + OpAdapterDesc &operator=(OpAdapterDesc &&desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; @@ -99,7 +99,7 @@ using TensorOrderMap = std::map>; class DfGraphConvertor { public: - explicit DfGraphConvertor(const AnfGraphPtr& anf_graph) + explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { #if (!defined ENABLE_GE) || (defined ENABLE_INFER) auto it_training = anf_graph->flags().find("training"); @@ -125,14 +125,14 @@ class DfGraphConvertor { ~DfGraphConvertor() {} - static void RegisterAdapter(const std::string& name, OpAdapterPtr adpt) { + static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { get_adpt_map()[name] = std::make_shared(adpt); } - static void RegisterAdapter(const std::string& name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { + static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); } - void DrawComputeGraph(const std::string& name) { + void DrawComputeGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -141,7 +141,7 @@ class DfGraphConvertor { fout << compute_sout_.str(); fout.close(); } - void DrawInitGraph(const std::string& name) { + void DrawInitGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -150,7 +150,7 @@ class DfGraphConvertor { fout << init_sout_.str(); fout.close(); } - void DrawSaveCheckpointGraph(const std::string& name) { + void DrawSaveCheckpointGraph(const 
std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -160,74 +160,74 @@ class DfGraphConvertor { fout.close(); } - DfGraphConvertor& ConvertAllNode(); - DfGraphConvertor& BuildGraph(); - DfGraphConvertor& InitParam(const TensorOrderMap& tensors); - DfGraphConvertor& GenerateCheckpointGraph(); - DfGraphConvertor& GenerateBroadcastGraph(const TensorOrderMap& tensors); - void InitParamWithData(const TensorOrderMap& tensors); - void SetOpInput(const OpAdapterPtr& adpt, const CNodePtr& node); - void SetupBroadcast(const std::shared_ptr& broadcast, const std::vector& broadcast_desc, - const DfGraphPtr& broadcast_graph, std::vector broadcast_input); - void MakeDatasetHandler(const std::string& name, const size_t& input_idx, const AnfNodePtr& it); - void SetupParamInitSubGraph(const TensorOrderMap& tensors, std::vector* init_input); - void DrawParamInitSubGraph(const std::string& name, const AnfNodePtr& it); + DfGraphConvertor &ConvertAllNode(); + DfGraphConvertor &BuildGraph(); + DfGraphConvertor &InitParam(const TensorOrderMap &tensors); + DfGraphConvertor &GenerateCheckpointGraph(); + DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); + void InitParamWithData(const TensorOrderMap &tensors); + void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); + void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input); + void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); + void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); + void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); DfGraphPtr GetComputeGraph(); DfGraphPtr GetInitGraph(); DfGraphPtr GetSaveCheckpointGraph(); DfGraphPtr GetBroadcastGraph(); - static OpAdapterPtr FindAdapter(const std::string& op_name, bool train = false); + static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); int ErrCode() const { return static_cast(error_); } - static std::unordered_map& get_adpt_map(); + static std::unordered_map &get_adpt_map(); bool is_training() const { return training_; } void set_training(bool is_training) { training_ = is_training; } protected: - void InitLoopVar(std::vector* init_input); + void InitLoopVar(std::vector *init_input); private: std::ostringstream compute_sout_; std::ostringstream init_sout_; std::ostringstream checkpoint_sout_; std::ostringstream restore_checkpoint_sout_; - std::unordered_map op_draw_name_; + std::unordered_map op_draw_name_; - AnfNodePtr TraceTupleGetItem(const CNodePtr& node, unsigned int* index); - AnfNodePtr TraceMakeTuple(const CNodePtr& node, unsigned int index); - AnfNodePtr TraceDepend(const CNodePtr& node); + AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); + AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); + AnfNodePtr TraceDepend(const CNodePtr &node); OutHandler TraceRealOp(AnfNodePtr node); - OutHandler GetHandler(const AnfNodePtr& node, const std::stack& index_stack, AnfNode* const draw_index); + OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); OperatorPtr Convert(AnfNodePtr node); OperatorPtr ConvertCNode(CNodePtr node); std::vector ConvertDependNode(AnfNodePtr node); AnfNodePtr GetRealOpNode(AnfNodePtr node); - 
std::vector GetDependNodes(const AnfNodePtr& node); + std::vector GetDependNodes(const AnfNodePtr &node); OperatorPtr ConvertParameter(AnfNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node); void ConvertTupleGetItem(const CNodePtr node); - void GetDependOnParameterUse(const CNodePtr& node, const AnfNodePtr& src_node, const AnfNodePtr& dest_node, - const std::shared_ptr>& src_ops_list, - const std::shared_ptr>& dst_ops_list); - bool GetControlDependList(const CNodePtr& node, const std::shared_ptr>& src_ops_list, - const std::shared_ptr>& dst_ops_list); - void DrawControlDepend(const AnfNodePtr& src_node, const AnfNodePtr& dest_node); + void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); void ConvertControlDependNode(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node); - bool CheckCNode(const std::string& name, const CNodePtr node); + bool CheckCNode(const std::string &name, const CNodePtr node); void TraceOutput(AnfNodePtr node); - void TraceOutputFromParameter(const AnfNodePtr& anf_out); - void TraceOutputFromTupleGetItem(const AnfNodePtr& anf_out); + void TraceOutputFromParameter(const AnfNodePtr &anf_out); + void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); void SetNodeInput(AnfNodePtr node); void SetOpControlInput(const AnfNodePtr node); void UpdateOpDesc(AnfNodePtr node); void BuildSaveCheckpointGraph(); void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); - void UpdateDataOpDesc(const AnfNodePtr& it, const OperatorPtr& op) const; - void AddGraphConstInput(const OperatorPtr& op); + void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; + void AddGraphConstInput(const OperatorPtr &op); std::shared_ptr anf_graph_{nullptr}; std::shared_ptr df_graph_{nullptr}; @@ -235,12 +235,12 @@ class DfGraphConvertor { std::shared_ptr save_ckp_graph_{nullptr}; std::shared_ptr restore_ckp_graph_{nullptr}; std::shared_ptr broadcast_graph_{nullptr}; - std::unordered_map op_cache_; - std::unordered_map> control_depend_cache_; + std::unordered_map op_cache_; + std::unordered_map> control_depend_cache_; /* record "tuple_getitem"<->"out_handler" mapping */ - std::unordered_map out_handle_cache_; + std::unordered_map out_handle_cache_; /* record "make_tuple"<->"out_handler vector" mapping */ - std::unordered_map>> tuple_out_handle_cache_; + std::unordered_map>> tuple_out_handle_cache_; std::unordered_map params_; std::unordered_map vars_; std::vector> graph_outputs_; diff --git a/mindspore/ccsrc/transform/df_graph_manager.cc b/mindspore/ccsrc/transform/df_graph_manager.cc index bfe4d9f5d23..f62c3865877 100644 --- a/mindspore/ccsrc/transform/df_graph_manager.cc +++ b/mindspore/ccsrc/transform/df_graph_manager.cc @@ -31,8 +31,8 @@ namespace mindspore { namespace transform { -DfGraphWrapper::DfGraphWrapper(const std::string& name, const int& id, const DfGraphPtr& graph_ptr, - const OptionMap& options) +DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, + const OptionMap &options) : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} DfGraphManager::DfGraphManager() { @@ -49,7 +49,7 
@@ DfGraphManager::~DfGraphManager() { parse::python_adapter::set_python_env_flag(false); } -DfGraphManager& DfGraphManager::GetInstance() { +DfGraphManager &DfGraphManager::GetInstance() { static DfGraphManager instance; return instance; } @@ -63,7 +63,7 @@ int DfGraphManager::GenerateId() { return graph_id_; } -Status DfGraphManager::AddGraph(const std::string& name, const DfGraphPtr& graph_ptr, const OptionMap& options) { +Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { std::lock_guard lg(lock_); if (name.empty()) { MS_LOG(ERROR) << "The graph name is null, add graph failed"; @@ -101,9 +101,9 @@ std::vector DfGraphManager::GetAllGraphs() { } std::set DfGraphManager::GetSavedGraphs() { return saved_graphs_; } -void DfGraphManager::AddSavedGraphs(const std::string& id) { saved_graphs_.insert(id); } +void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } -DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string& name) { +DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { std::lock_guard lg(lock_); if (name.empty()) { MS_LOG(ERROR) << "The graph name is null"; @@ -126,7 +126,7 @@ void DfGraphManager::ClearGraph() noexcept { MS_LOG(INFO) << "Remove all graphs in GraphManager"; } -void DfGraphManager::SetAnfGraph(const std::string& name, const AnfGraphPtr& anf_graph_ptr) { +void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { DfGraphWrapperPtr df_graph = GetGraphByName(name); if (df_graph == nullptr) { MS_LOG(ERROR) << "Can't found graph name: " << name; @@ -152,7 +152,7 @@ void DfGraphManager::EraseAnfGraph() { anf_graphs_.clear(); } -void DfGraphManager::SetGeSession(const std::shared_ptr& sess_ptr) { +void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { std::lock_guard lg(lock_); if (sess_ptr == nullptr) { MS_LOG(WARNING) << "You are adding a empty Ge Session"; @@ -182,7 +182,7 @@ void DfGraphManager::DeleteGeSession() noexcept { } } -void DfGraphManager::SetGraphRunner(const std::shared_ptr& graph_runner_ptr) noexcept { +void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { std::lock_guard lg(lock_); if (graph_runner_ptr == nullptr) { MS_LOG(WARNING) << "You are adding a empty GraphRunner"; diff --git a/mindspore/ccsrc/transform/df_graph_manager.h b/mindspore/ccsrc/transform/df_graph_manager.h index 97137ae94bb..2ca43d1f073 100644 --- a/mindspore/ccsrc/transform/df_graph_manager.h +++ b/mindspore/ccsrc/transform/df_graph_manager.h @@ -35,7 +35,7 @@ using OptionMap = std::map; struct DfGraphWrapper { public: - DfGraphWrapper(const std::string& name, const int& id, const DfGraphPtr& graph_ptr, const OptionMap& options); + DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); ~DfGraphWrapper() {} std::string name_; @@ -51,19 +51,19 @@ class DfGraphManager { ~DfGraphManager(); void ClearGraph() noexcept; - static DfGraphManager& GetInstance(); - Status AddGraph(const std::string& name, const DfGraphPtr& graph, const OptionMap& options = {}); + static DfGraphManager &GetInstance(); + Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); std::vector GetAllGraphs(); std::set GetSavedGraphs(); - void AddSavedGraphs(const std::string& id); - DfGraphWrapperPtr GetGraphByName(const std::string& name); - DfGraphManager(const DfGraphManager&) = delete; - void SetAnfGraph(const 
std::string& name, const AnfGraphPtr& anf_graph_ptr); + void AddSavedGraphs(const std::string &id); + DfGraphWrapperPtr GetGraphByName(const std::string &name); + DfGraphManager(const DfGraphManager &) = delete; + void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); AnfGraphPtr GetAnfGraph(uint32_t graph_id); std::shared_ptr GetGraphRunner(); - void SetGraphRunner(const std::shared_ptr& graph_runner_ptr) noexcept; + void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; void DeleteGraphRunner() noexcept; - void SetGeSession(const std::shared_ptr& sess_ptr); + void SetGeSession(const std::shared_ptr &sess_ptr); std::shared_ptr GetGeSession(); void DeleteGeSession() noexcept; void EraseAnfGraph(); diff --git a/mindspore/ccsrc/transform/graph_builder.cc b/mindspore/ccsrc/transform/graph_builder.cc index 9c05969fb03..785c5c7f3a0 100644 --- a/mindspore/ccsrc/transform/graph_builder.cc +++ b/mindspore/ccsrc/transform/graph_builder.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace transform { -DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam& param) { +DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam ¶m) { MS_LOG(INFO) << "BuildMDDatasetGraph."; // InitData @@ -37,7 +37,7 @@ DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam& param) { return dataset_graph; } -Status BuildDatasetGraph(const DatasetGraphParam& param, const std::string& phase) { +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) { Status ret; std::string graph_name = phase; diff --git a/mindspore/ccsrc/transform/graph_builder.h b/mindspore/ccsrc/transform/graph_builder.h index 30b891460bb..3d959f5a85c 100644 --- a/mindspore/ccsrc/transform/graph_builder.h +++ b/mindspore/ccsrc/transform/graph_builder.h @@ -27,7 +27,7 @@ namespace mindspore { namespace transform { -Status BuildDatasetGraph(const DatasetGraphParam& param, const std::string& phase = "dataset"); +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); } // namespace transform } // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_runner.cc b/mindspore/ccsrc/transform/graph_runner.cc index 8b0ddfd18d6..52d0d8e17fe 100644 --- a/mindspore/ccsrc/transform/graph_runner.cc +++ b/mindspore/ccsrc/transform/graph_runner.cc @@ -30,7 +30,7 @@ #ifdef NO_GE_CLIENT namespace ge { -Session::Session(const std::map& options) { +Session::Session(const std::map &options) { if (options.empty()) { MS_LOG(ERROR) << "session input options is empty"; } @@ -42,7 +42,7 @@ Session::~Session() {} namespace mindspore { namespace transform { -std::shared_ptr GraphRunner::NewSession(const SessionOptions& sess_options) { +std::shared_ptr GraphRunner::NewSession(const SessionOptions &sess_options) { std::shared_ptr ret = std::make_shared(sess_options); if (ret == nullptr) { MS_LOG(ERROR) << "Create GE session failed"; @@ -52,7 +52,7 @@ std::shared_ptr GraphRunner::NewSession(const SessionOptions& sess_ return ret; } -GraphRunner::GraphRunner(const GraphRunnerOptions& options) +GraphRunner::GraphRunner(const GraphRunnerOptions &options) : options_(options), graph_manager_(DfGraphManager::GetInstance()) { if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; @@ -88,7 +88,7 @@ GraphRunner::GraphRunner(const GraphRunnerOptions& options) } #ifdef ENABLE_GE - for (auto& it : wrappers) { + for (auto &it : wrappers) { std::set saved_graph = graph_manager_.GetSavedGraphs(); auto iter_find 
= saved_graph.find(std::to_string(it->id_)); if (iter_find != saved_graph.end()) { @@ -101,8 +101,8 @@ GraphRunner::GraphRunner(const GraphRunnerOptions& options) #endif } -Status GraphRunner::RunGraph(const RunOptions& options, const std::vector& inputs, - std::vector* outputs) { +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *outputs) { std::string name = options.name; if (name.empty()) { MS_LOG(ERROR) << "The graph name is null"; @@ -125,7 +125,7 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vector ge_outputs; (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs), - [](const GeTensorPtr& i) { return *i; }); + [](const GeTensorPtr &i) { return *i; }); MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs"; @@ -161,19 +161,19 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vector(ge_tensor); }); + [](const GeTensor &ge_tensor) { return std::make_shared(ge_tensor); }); return Status::SUCCESS; } -Status GraphRunner::RunGraph(const RunOptions& options, const std::vector& inputs, - std::vector* const outputs) { +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *const outputs) { std::vector ge_inputs; for (auto it : inputs) { MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize(); auto shape = (*it).shape(); std::string shape_str; - for (const auto& elem : shape) { + for (const auto &elem : shape) { shape_str += std::to_string(elem); shape_str += " "; } @@ -199,7 +199,7 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vectoremplace_back(tensor); diff --git a/mindspore/ccsrc/transform/graph_runner.h b/mindspore/ccsrc/transform/graph_runner.h index a9aa9fbc59a..728a1a25a25 100644 --- a/mindspore/ccsrc/transform/graph_runner.h +++ b/mindspore/ccsrc/transform/graph_runner.h @@ -46,16 +46,16 @@ struct RunOptions { class GraphRunner { public: - explicit GraphRunner(const GraphRunnerOptions& options); + explicit GraphRunner(const GraphRunnerOptions &options); ~GraphRunner() { sess_ = nullptr; } - Status RunGraph(const RunOptions& options, const std::vector& inputs, std::vector* outputs); - Status RunGraph(const RunOptions& options, const std::vector& inputs, std::vector* outputs); - static std::shared_ptr NewSession(const SessionOptions& sess_options); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + static std::shared_ptr NewSession(const SessionOptions &sess_options); private: std::shared_ptr sess_; transform::GraphRunnerOptions options_; - DfGraphManager& graph_manager_; + DfGraphManager &graph_manager_; }; } // namespace transform } // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index 421e4c45690..2039dfa7d6c 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -26,17 +26,17 @@ #include "utils/utils.h" namespace mindspore { namespace transform { -static uint32_t CustomInferFunc(const Operator&) { return 0; } +static uint32_t CustomInferFunc(const Operator &) { return 0; } template class OpAdapter : public BaseOpAdapter { public: using OpType = T; OpAdapter() {} - explicit OpAdapter(const ExtraAttr& extra_attr) : extra_attr_(extra_attr) {} + explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {} 
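The OpAdapter hunks that follow (SetNormalOpInput, setInput, getOutput) all consult a map from the ANF input index to the GE input it should be wired to; a minimal sketch of that lookup, with hypothetical stand-in types rather than the generated GE operator classes:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Op {                                   // stand-in for a GE operator
  std::string name;
  std::map<std::string, std::string> inputs;  // input name -> producer name
};
using OpPtr = std::shared_ptr<Op>;

struct InputDesc {
  std::string name;                           // GE-side input name
  std::function<void(const OpPtr &, const OpPtr &)> set;
};

class TinyAdapter {
 public:
  // Mirrors the idea of setInput: look up the ANF index and wire the input if found.
  int setInput(const OpPtr &op, int index, const OpPtr &input) {
    auto it = input_map_.find(index);
    if (it == input_map_.end()) return 1;  // NOT_FOUND
    it->second.set(op, input);
    return 0;                              // SUCCESS
  }
  std::map<int, InputDesc> input_map_;
};

int main() {
  TinyAdapter adapter;
  adapter.input_map_[1] = {"x", [](const OpPtr &op, const OpPtr &in) { op->inputs["x"] = in->name; }};
  auto conv = std::make_shared<Op>();
  conv->name = "Conv2D";
  auto data = std::make_shared<Op>();
  data->name = "data0";
  int rc = adapter.setInput(conv, 1, data);
  std::cout << "rc=" << rc << " Conv2D.x=" << conv->inputs["x"] << std::endl;
  return 0;
}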
~OpAdapter() override {} - bool IsCustomOp(const OperatorPtr& op) { + bool IsCustomOp(const OperatorPtr &op) { MS_EXCEPTION_IF_NULL(op); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { @@ -45,7 +45,7 @@ class OpAdapter : public BaseOpAdapter { return true; } - Status GenerateCustomOpInputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) { + Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(prim); // Create the map of custom op from input index to input name. @@ -69,7 +69,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - Status GenerateCustomOpOutputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) { + Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(prim); // Create the map of custom op from output index to output name. @@ -122,7 +122,7 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr GenerateNormalOp(const AnfNodePtr& anf) { + OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { OperatorPtr op = nullptr; // There are duplicate names in ANF graph, do not assign ANF node name to GE // GE will generate unique name automatically @@ -148,7 +148,7 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr generate(const AnfNodePtr& anf) override { + OperatorPtr generate(const AnfNodePtr &anf) override { OperatorPtr op = nullptr; if (IsCustomCNode(anf)) { op = GenerateCustomOp(anf); @@ -158,21 +158,21 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr generate(const std::string& op_name) override { return std::make_shared(op_name); } + OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } - const std::unordered_map& getInputMap() override { return input_map_; } - const std::unordered_map& getInputAttrMap() override { return input_attr_map_; } - const std::unordered_map& getDynInputMap() override { return dyn_input_map_; } - const std::unordered_map& getOutputMap() override { return output_map_; } + const std::unordered_map &getInputMap() override { return input_map_; } + const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } + const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } + const std::unordered_map &getOutputMap() override { return output_map_; } - Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OperatorPtr& input) { + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(input); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { return NOT_FOUND; } - std::unordered_map& input_map = it->second; + std::unordered_map &input_map = it->second; if ((input_map.find(index) != input_map.end())) { MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index]; @@ -182,7 +182,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - Status SetNormalOpInput(const OperatorPtr& op, int index, const OperatorPtr& input) { + Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); if (it != input_map_.end()) { @@ -194,7 +194,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) 
override { + int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { if (IsCustomOp(op)) { auto cus_op = std::dynamic_pointer_cast(op); return static_cast(SetCustomOpInput(cus_op, index, input)); @@ -203,14 +203,14 @@ class OpAdapter : public BaseOpAdapter { } } - Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OutHandler& handle) { + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { MS_EXCEPTION_IF_NULL(op); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { return NOT_FOUND; } - std::unordered_map& input_map = it->second; + std::unordered_map &input_map = it->second; if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { if (handle.out.empty()) { MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; @@ -225,7 +225,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - Status SetNormalOpInput(const OperatorPtr& op, int index, const OutHandler& handle) { + Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); if ((handle.op != nullptr) && (it != input_map_.end())) { @@ -242,7 +242,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - int setInput(const OperatorPtr& op, int index, const OutHandler& handle) override { + int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { if (IsCustomOp(op)) { auto cus_op = std::dynamic_pointer_cast(op); return static_cast(SetCustomOpInput(cus_op, index, handle)); @@ -251,7 +251,7 @@ class OpAdapter : public BaseOpAdapter { } } - int setInput(const OperatorPtr& op, int index, const std::shared_ptr>& handler_vec) override { + int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { MS_EXCEPTION_IF_NULL(handler_vec); if (IsCustomOp(op)) { MS_LOG(ERROR) << "Custom Op do not support dynamic input"; @@ -278,7 +278,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(NOT_FOUND); } - OutHandler getOutput(const OperatorPtr& op, int index) override { + OutHandler getOutput(const OperatorPtr &op, int index) override { MS_EXCEPTION_IF_NULL(op); if (IsCustomOp(op)) { return getCustomOutput(op, index); @@ -286,7 +286,7 @@ class OpAdapter : public BaseOpAdapter { return getNormalOutput(op, index); } - OutHandler getCustomOutput(const OperatorPtr& op, int index) { + OutHandler getCustomOutput(const OperatorPtr &op, int index) { MS_EXCEPTION_IF_NULL(op); auto it = cus_output_map_.find(op->GetOpType()); if (it == cus_output_map_.end()) { @@ -294,7 +294,7 @@ class OpAdapter : public BaseOpAdapter { return OutHandler(); } - std::unordered_map& output_map = it->second; + std::unordered_map &output_map = it->second; if ((output_map.find(index) != output_map.end())) { return OutHandler(op, output_map[index]); @@ -303,7 +303,7 @@ class OpAdapter : public BaseOpAdapter { return OutHandler(); } - OutHandler getNormalOutput(const OperatorPtr& op, int index) { + OutHandler getNormalOutput(const OperatorPtr &op, int index) { MS_EXCEPTION_IF_NULL(op); if (!dyn_output_map_.empty() && !output_map_.empty()) { MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; @@ -320,7 +320,7 @@ class OpAdapter : public BaseOpAdapter { } } - Status UpdateSingleOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) { + Status 
UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { MS_EXCEPTION_IF_NULL(type); std::string format = "NCHW"; if (op->GetOpType() == kExtractImagePatchesOpName) { @@ -353,7 +353,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - size_t GetCustomOpOutputSize(const CusOperatorPtr& cus_op) { + size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { MS_EXCEPTION_IF_NULL(cus_op); if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { MS_LOG(ERROR) << "This op does not create custom output map"; @@ -363,8 +363,8 @@ class OpAdapter : public BaseOpAdapter { return output_size; } - std::shared_ptr CreateOutputDesc(const abstract::ShapePtr& shape_ptr, const TypePtr& type, - const std::string& format) { + std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, + const std::string &format) { if (shape_ptr == nullptr) { MS_LOG(ERROR) << "Shape ptr is nullptr"; return nullptr; @@ -383,7 +383,7 @@ class OpAdapter : public BaseOpAdapter { return desc; } - Status UpdateMultiOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) { + Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { auto tuple_shp = dyn_cast(shp); MS_EXCEPTION_IF_NULL(tuple_shp); @@ -432,7 +432,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - std::shared_ptr CreateNodeDesc(const AnfNodePtr& node) { + std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); TypeId me_type = node->Type()->type_id(); if (kObjectTypeTensorType == me_type) { @@ -456,7 +456,7 @@ class OpAdapter : public BaseOpAdapter { return desc; } - void UpdateNormalOpInputDesc(const OperatorPtr& op, const AnfNodePtr node) { + void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -479,7 +479,7 @@ class OpAdapter : public BaseOpAdapter { } } - void UpdateCustomOpInputDesc(const CusOperatorPtr& op, const AnfNodePtr& node) { + void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -491,7 +491,7 @@ class OpAdapter : public BaseOpAdapter { return; } - std::unordered_map& input_map = cus_input_map_[op->GetOpType()]; + std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; auto inputs = node->cast()->inputs(); for (size_t i = 1; i < inputs.size(); ++i) { if (input_map.find(i) != input_map.end()) { @@ -504,7 +504,7 @@ class OpAdapter : public BaseOpAdapter { } } - void updateInputDesc(const OperatorPtr& op, const AnfNodePtr& node) { + void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(node); if (IsCustomOp(op)) { @@ -515,8 +515,8 @@ class OpAdapter : public BaseOpAdapter { } } - void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type, - const AnfNodePtr& node) override { + void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) override { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -548,7 +548,7 @@ class OpAdapter : public BaseOpAdapter { updateInputDesc(op, node); } - int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) override { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const 
ValuePtr &attrValue) override { auto it = attr_map_.find(attrKey); if (it != attr_map_.end()) { // switch case for each avalilable attribute type @@ -560,7 +560,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(NOT_FOUND); } - int SetCustomOpAttr(const CusOperatorPtr& op, const PrimitivePtr& prim) { + int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { enum ValueType { SINGLE_VALUE = 0, SEQUEUE_VALUE, @@ -611,11 +611,11 @@ class OpAdapter : public BaseOpAdapter { return 0; } - int SetNormalOpAttr(const OperatorPtr& op, const PrimitivePtr& prim) { + int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { int ret = 0; MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(op); - for (auto& it : attr_map_) { + for (auto &it : attr_map_) { auto value = prim->GetAttr(it.first); if (value != nullptr) { // set attr from primitive @@ -637,7 +637,7 @@ class OpAdapter : public BaseOpAdapter { return 0; } - int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) override { + int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { int ret = 0; if (IsCustomPrim(prim)) { auto cus_op = std::dynamic_pointer_cast(op); @@ -648,7 +648,7 @@ class OpAdapter : public BaseOpAdapter { return ret; } - int setAttr(const OperatorPtr& op, const AnfNodePtr& node) override { + int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { // no attribute for lonely node MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -660,7 +660,7 @@ class OpAdapter : public BaseOpAdapter { return 0; } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); if (inputs.empty()) { return 0; } @@ -691,7 +691,7 @@ class OpAdapter : public BaseOpAdapter { } // set attr from const input - for (auto& it : input_attr_map_) { + for (auto &it : input_attr_map_) { if (inputs.size() <= it.first || !inputs[it.first]->isa()) { continue; } @@ -711,38 +711,38 @@ class OpAdapter : public BaseOpAdapter { private: template - static S ConvertAny(const ValuePtr& value, const AnyTraits&) { + static S ConvertAny(const ValuePtr &value, const AnyTraits &) { return GetValue(value); } // specialization for reverse bool - static bool ConvertAny(const ValuePtr& value, const AnyTraits&, bool reverse) { + static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { return reverse != GetValue(value); } template - static Q ConvertAny(const ValuePtr& value, const AnyTraits
& traits_from, const AnyTraits& traits_to) { + static Q ConvertAny(const ValuePtr &value, const AnyTraits
&traits_from, const AnyTraits &traits_to) { return ConvertAnyUtil(value, traits_from, traits_to); } // specialization for tensor - static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits& traits) { + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { // To-DO the format may read from ME tensor return ConvertAnyUtil(value, traits); } // specialization for int - static int64_t ConvertAny(const ValuePtr& value, const AnyTraits) { + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { return static_cast(GetValue(value)); } // specialization for int to Vector - static std::vector ConvertAny(const ValuePtr& value, const std::string& name, + static std::vector ConvertAny(const ValuePtr &value, const std::string &name, const AnyTraits> anyTraitsInt) { return ConvertAnyUtil(value, name, anyTraitsInt); } - static std::vector> ConvertAny(const ValuePtr& value, + static std::vector> ConvertAny(const ValuePtr &value, const AnyTraits>>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(INFO) << "Value: " << value->type_name(); @@ -752,14 +752,14 @@ class OpAdapter : public BaseOpAdapter { } auto vec = value->cast(); MS_EXCEPTION_IF_NULL(vec); - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { MS_EXCEPTION_IF_NULL(it); if (!it->isa()) { MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); } auto sub_vector = it->cast(); std::vector sublist; - for (auto& item : sub_vector->value()) { + for (auto &item : sub_vector->value()) { sublist.push_back(static_cast(GetValue(item))); } list.push_back(sublist); @@ -767,7 +767,7 @@ class OpAdapter : public BaseOpAdapter { return list; } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits>>, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(DEBUG) << "Value: " << value->type_name(); @@ -776,20 +776,20 @@ class OpAdapter : public BaseOpAdapter { } auto vec = value->cast(); std::vector list; - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { MS_EXCEPTION_IF_NULL(it); if (!it->isa()) { MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); } auto sub_vector = it->cast(); - for (auto& item : sub_vector->value()) { + for (auto &item : sub_vector->value()) { list.push_back(static_cast(GetValue(item))); } } return list; } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits>, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(INFO) << "Value: " << value->type_name(); @@ -797,7 +797,7 @@ class OpAdapter : public BaseOpAdapter { if (value->isa()) { auto vec = value->cast(); MS_EXCEPTION_IF_NULL(vec); - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { list.push_back(static_cast(GetValue(it))); } return list; @@ -809,17 +809,17 @@ class OpAdapter : public BaseOpAdapter { MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); } - static std::string ConvertAny(const ValuePtr& value, const AnyTraits> anyTraitsVec, + static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsStr) { return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits> anyTraitsVec, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsFlo) { return ConvertAnyUtil(value, 
anyTraitsVec, anyTraitsFlo); } - static std::vector ConvertAny(const ValuePtr& value, const std::string& format, + static std::vector ConvertAny(const ValuePtr &value, const std::string &format, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsInt) { return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); @@ -827,12 +827,12 @@ class OpAdapter : public BaseOpAdapter { // convert value list for value tuple to vector template - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits
& anyTraitsP, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits
&anyTraitsP, const AnyTraits> anyTraitsQ) { return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); } - static int64_t ConvertAny(const ValuePtr& value, const AnyTraits) { + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { auto name = GetValue(value); auto it = enum_map_.find(name); int v = 0; @@ -842,12 +842,12 @@ class OpAdapter : public BaseOpAdapter { return v; } - static GeDataType ConvertAny(const ValuePtr& value, const AnyTraits anyTraitsGE) { + static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { return ConvertAnyUtil(value, anyTraitsGE); } // convert any value to tensor - static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits anyTraitsValue) { + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { return ConvertAnyUtil(value, anyTraitsValue); } diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h index 99106b87611..01f96e251db 100644 --- a/mindspore/ccsrc/transform/op_adapter_base.h +++ b/mindspore/ccsrc/transform/op_adapter_base.h @@ -48,15 +48,17 @@ namespace ge { class CustomOperator : public Operator { public: - CustomOperator(const string& name, const string& type) : Operator(name, type) {} + CustomOperator(const string &name, const string &type) : Operator(name, type) {} ~CustomOperator() override{}; - void CustomInputRegister(const string& name) { Operator::InputRegister(name); } + void CustomInputRegister(const string &name) { Operator::InputRegister(name); } - void CustomOutputRegister(const string& name) { Operator::OutputRegister(name); } + void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } - void CustomInferFuncRegister(const std::function& func) { Operator::InferFuncRegister(func); } + void CustomInferFuncRegister(const std::function &func) { + Operator::InferFuncRegister(func); + } }; } // namespace ge @@ -69,7 +71,7 @@ struct OutHandler { OperatorPtr op; std::string out; OutHandler() : op(nullptr), out("") {} - OutHandler(const OperatorPtr& op, const std::string out) : op(op), out(out) {} + OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} }; struct ControlEdge { @@ -119,33 +121,33 @@ struct DynOutputDesc { class BaseOpAdapter { public: virtual ~BaseOpAdapter() {} - virtual OperatorPtr generate(const AnfNodePtr& anf) = 0; - virtual OperatorPtr generate(const std::string& type) { return std::make_shared(type); } - virtual int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) = 0; - virtual int setInput(const OperatorPtr& op, int index, const OutHandler& handle) = 0; - virtual int setInput(const OperatorPtr& op, int index, - const std::shared_ptr>& handler_vec) = 0; - virtual int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) = 0; - virtual int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) = 0; - virtual int setAttr(const OperatorPtr& op, const AnfNodePtr& node) = 0; + virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; + virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } + virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; + virtual int setInput(const OperatorPtr &op, int index, + const std::shared_ptr> &handler_vec) = 0; + virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; + virtual int 
setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; + virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; virtual std::unordered_map GetExtraAttr() = 0; template ::value>::type> - int setAttr(const OperatorPtr& op, const std::string& attrKey, const std::shared_ptr& attrValue) { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } template ::value>::type> - int setAttr(const OperatorPtr& op, const std::string& attrKey, const T& attrValue) { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } - virtual OutHandler getOutput(const OperatorPtr& op, int index) = 0; - virtual void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type, - const AnfNodePtr& node) = 0; - virtual const std::unordered_map& getInputMap() = 0; - virtual const std::unordered_map& getInputAttrMap() = 0; - virtual const std::unordered_map& getDynInputMap() = 0; - virtual const std::unordered_map& getOutputMap() = 0; - void AddAttrToDrawGraph(const std::string& attr_str) { attrs_vec_.push_back(attr_str); } - const std::vector& GetAttrsFromDrawGraph() const { return attrs_vec_; } + virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; + virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) = 0; + virtual const std::unordered_map &getInputMap() = 0; + virtual const std::unordered_map &getInputAttrMap() = 0; + virtual const std::unordered_map &getDynInputMap() = 0; + virtual const std::unordered_map &getOutputMap() = 0; + void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } + const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } void clearAttrVect() { attrs_vec_.clear(); } private: diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index d52699fa8fe..0163b80f083 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace transform { -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits&) { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { // To-DO the format may read from ME tensor MS_EXCEPTION_IF_NULL(value); auto me_tensor = value->cast(); @@ -33,7 +33,7 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits ConvertAnyUtil(const ValuePtr& value, const std::string& name, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>) { int64_t data = GetValue(value); std::vector list; @@ -50,7 +50,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& na return list; } -std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits) { +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); if (nullptr == vec) { @@ -58,7 +58,7 @@ std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraitsvalue()) { + for (auto &it : vec->value()) { if (i != 0) { buffer << ","; } @@ -68,7 +68,7 @@ std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits) { +std::vector ConvertAnyUtil(const ValuePtr &value, const 
AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); if (nullptr == vec) { @@ -77,11 +77,11 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const AnyTraits list; list.resize(vec->value().size()); (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr& val) { return static_cast(GetValue(val)); }); + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); return list; } -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& format, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); @@ -91,7 +91,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& fo std::vector list; list.resize(vec->value().size()); (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr& val) { return static_cast(GetValue(val)); }); + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); if (format == kOpFormat_NHWC) { if (list.size() < 4) { MS_LOG(EXCEPTION) << "The size of list is less than 4"; @@ -105,7 +105,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& fo return list; } -GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); if (!value->isa()) { MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() @@ -120,7 +120,7 @@ GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { return TransformUtil::ConvertDataType(me_type); } -GeTensor VectorToTensorUtil(const ValuePtr& value) { +GeTensor VectorToTensorUtil(const ValuePtr &value) { // convert tuple or list to ge tensor, only supported one dim for now MS_EXCEPTION_IF_NULL(value); auto vec = value->isa() ? 
value->cast()->value() : value->cast()->value(); @@ -136,7 +136,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); } else if (vec[0]->isa()) { MS_LOG(INFO) << "convert value to tensor with data type = Float32"; auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); @@ -144,7 +144,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); } else if (vec[0]->isa()) { MS_LOG(INFO) << "convert value to tensor with data type = Bool"; // We use uint8_t to save bool type data @@ -153,7 +153,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); + return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); } else { MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); } @@ -161,7 +161,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { return GeTensor(); } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); if (value->isa()) { // convert me tensor to ge tensor @@ -174,28 +174,28 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); } else if (value->isa()) { // convert scalar Int64 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); } else if (value->isa()) { // convert scalar FP32 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); } else if (value->isa()) { // convert scalar FP32 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); } else if (value->isa()) { // convert String to GeTensor MS_LOG(INFO) << "convert string to tensor with data type = String"; @@ -213,7 +213,7 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { return GeTensor(); } -bool IsCustomPrim(const PrimitivePtr& prim) { +bool IsCustomPrim(const PrimitivePtr &prim) { if (prim == nullptr) { return false; } @@ -232,7 +232,7 @@ 
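The conversion hunks above repeatedly hand typed host data to GeTensor as an untyped byte buffer via reinterpret_cast<uint8_t *>(data.data()) and data.size() * sizeof(T). A minimal standalone sketch of that packing step, using only the standard library (the GeTensor and GeTensorDesc types themselves are not reproduced here), written with the right-aligned pointer style this patch adopts:

#include <cstdint>
#include <cstring>
#include <vector>

// Copy a typed vector into an untyped byte buffer, mirroring the
// reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(T) pattern above.
template <typename T>
std::vector<uint8_t> PackAsBytes(const std::vector<T> &data) {
  std::vector<uint8_t> buffer(data.size() * sizeof(T));
  if (!data.empty()) {
    std::memcpy(buffer.data(), reinterpret_cast<const uint8_t *>(data.data()), buffer.size());
  }
  return buffer;
}

int main() {
  std::vector<int32_t> dims = {2, 3, 4};
  std::vector<uint8_t> raw = PackAsBytes(dims);  // 12 bytes for three int32_t values
  return raw.size() == dims.size() * sizeof(int32_t) ? 0 : 1;
}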
bool IsCustomPrim(const PrimitivePtr& prim) { return is_custom_op; } -bool IsCustomCNode(const AnfNodePtr& anf) { +bool IsCustomCNode(const AnfNodePtr &anf) { if (anf == nullptr) { return false; } diff --git a/mindspore/ccsrc/transform/op_adapter_util.h b/mindspore/ccsrc/transform/op_adapter_util.h index 0cb6c763b2f..fcabc732d58 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.h +++ b/mindspore/ccsrc/transform/op_adapter_util.h @@ -25,42 +25,42 @@ namespace mindspore { namespace transform { template -static Q ConvertAnyUtil(const ValuePtr& value, const AnyTraits
&, const AnyTraits&) { +static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits
&, const AnyTraits &) { return static_cast(GetValue
(value)); } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits& traits); +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& name, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>); -std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits); +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); -std::vector ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits); +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& format, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, const AnyTraits>, const AnyTraits); -GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits); +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); template -std::vector ConvertAnyUtil(const ValuePtr& value, AnyTraits
, const AnyTraits>) { +std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits
, const AnyTraits>) { if (!value->isa() && !value->isa()) { MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() << ", type: " << value->type_name() << ", value should be a tuple or list"; } auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); std::vector data; - for (auto& it : vec) { + for (auto &it : vec) { data.push_back(ConvertAnyUtil(it, AnyTraits
(), AnyTraits())); } return data; } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits); +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); -bool IsCustomPrim(const PrimitivePtr& prim); -bool IsCustomCNode(const AnfNodePtr& node); +bool IsCustomPrim(const PrimitivePtr &prim); +bool IsCustomCNode(const AnfNodePtr &node); } // namespace transform } // namespace mindspore #endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/util.cc b/mindspore/ccsrc/transform/util.cc index 0a18763d120..b1120ade6d1 100644 --- a/mindspore/ccsrc/transform/util.cc +++ b/mindspore/ccsrc/transform/util.cc @@ -53,7 +53,7 @@ static std::map datatype_trans_map = { {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32}, {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}}; -GeDataType TransformUtil::ConvertDataType(const MeDataType& type) { +GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type"; if (datatype_trans_map.find(type) != datatype_trans_map.end()) { return datatype_trans_map[type]; @@ -70,7 +70,7 @@ static std::map datatype_size_map = { {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)}, {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}}; -size_t TransformUtil::GetDataTypeSize(const MeDataType& type) { +size_t TransformUtil::GetDataTypeSize(const MeDataType &type) { if (datatype_size_map.find(type) != datatype_size_map.end()) { return datatype_size_map[type]; } else { @@ -79,7 +79,7 @@ size_t TransformUtil::GetDataTypeSize(const MeDataType& type) { } } -GeFormat TransformUtil::ConvertFormat(const string& format) { +GeFormat TransformUtil::ConvertFormat(const string &format) { if (format == kOpFormat_NCHW) { return GeFormat::FORMAT_NCHW; } else if (format == kOpFormat_NC1HWC0) { @@ -95,8 +95,8 @@ GeFormat TransformUtil::ConvertFormat(const string& format) { static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } -std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector& me_shape, - const MeDataType& me_type, const std::string& format) { +std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, + const MeDataType &me_type, const std::string &format) { // convert me shape to ge shape std::vector ge_shape; @@ -135,8 +135,8 @@ std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector TransformUtil::ConvertInputTensors(const std::vector& me_tensors, - const std::string& format) { +std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, + const std::string &format) { std::vector ge_tensors; for (size_t index = 0; index < me_tensors.size(); index++) { @@ -163,7 +163,7 @@ std::vector TransformUtil::ConvertInputTensors(const std::vectordata_type()); @@ -192,15 +192,15 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr& tensor, const std::s MS_LOG(ERROR) << "Failed to get Tensor Desc"; return nullptr; } - GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); + GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); if (tensor_ptr != nullptr) { MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; } return tensor_ptr; } -std::vector TransformUtil::ConvertGeTensors(const std::vector& ge_tensors, - const 
std::vector>& request_dims) { +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -222,7 +222,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector TransformUtil::ConvertGeTensors(const std::vector& ge_tensors) { +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -237,7 +237,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector& request_dims) { +bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); @@ -311,20 +311,20 @@ bool IsGeShapeCompatible(const GeShape& ge_shape, const std::vector& reques } } // namespace -GeShape TransformUtil::ConvertMeShape(const std::vector& me_dims) { +GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { std::vector ge_dims; (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); return GeShape(ge_dims); } -std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape) { +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { std::vector me_dims; std::vector ge_dims = ge_shape.GetDims(); (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); return me_dims; } -std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape, const std::vector& request_dims) { +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { vector ret; if (ge_shape.GetDimNum() == 0) { MS_LOG(DEBUG) << "GeTensor's shape is scalar"; @@ -340,12 +340,12 @@ std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape, const st return ret; } -MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr& ge_tensor, const std::vector& me_dims, - const TypeId& me_type) { +MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type) { MeTensor me_tensor(me_type, me_dims); // Get the writable data pointer of the tensor and cast it to its data type - auto me_data_ptr = reinterpret_cast(me_tensor.data_c(true)); + auto me_data_ptr = reinterpret_cast(me_tensor.data_c(true)); size_t me_data_size = static_cast(me_tensor.data().nbytes()); MS_EXCEPTION_IF_NULL(me_data_ptr); MS_EXCEPTION_IF_NULL(ge_tensor); @@ -369,7 +369,7 @@ MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr& ge_tensor, const return make_shared(me_tensor); } -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr& ge_tensor) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = ConvertGeShape(ge_shape); @@ -384,7 +384,7 @@ MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr& ge_tensor) { } // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector& request_dims) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = 
ConvertGeShape(ge_shape, request_dims); diff --git a/mindspore/ccsrc/transform/util.h b/mindspore/ccsrc/transform/util.h index 9bcd8dc115d..0f5d79f6a19 100644 --- a/mindspore/ccsrc/transform/util.h +++ b/mindspore/ccsrc/transform/util.h @@ -47,7 +47,7 @@ class TransformUtil { * Return: * [GeDataType] the data type for ge tensor * */ - static GeDataType ConvertDataType(const MeDataType& type); + static GeDataType ConvertDataType(const MeDataType &type); /* * Parameters: @@ -55,7 +55,7 @@ class TransformUtil { * Return: * [GeFormat] the data format for ge tensor * */ - static GeFormat ConvertFormat(const std::string& format); + static GeFormat ConvertFormat(const std::string &format); /* * Parameters: @@ -63,7 +63,7 @@ class TransformUtil { * Return: * [size_t] the buff size for the type in ME * */ - static size_t GetDataTypeSize(const MeDataType& type); + static size_t GetDataTypeSize(const MeDataType &type); /* * Parameters: @@ -73,8 +73,8 @@ class TransformUtil { * Return: * [shared_ptr] the shared pointer of ge tensor description * */ - static std::shared_ptr GetGeTensorDesc(const std::vector& shape, const MeDataType& me_type, - const std::string& format); + static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, + const std::string &format); /* * Parameters: @@ -84,7 +84,7 @@ class TransformUtil { * Return: * [GeTensor] the data tensor in GE * */ - static GeTensorPtr ConvertTensor(const MeTensorPtr& tensor, const std::string& format); + static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); /* * Parameters: @@ -93,8 +93,8 @@ class TransformUtil { * Return: * [std::vector] the data tensors in GE * */ - static std::vector ConvertInputTensors(const std::vector& me_tensors, - const std::string& format); + static std::vector ConvertInputTensors(const std::vector &me_tensors, + const std::string &format); /* * Parameters: @@ -102,7 +102,7 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(const GeTensorPtr& tensor); + static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); /* * Parameters: @@ -111,7 +111,7 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector& request_dims); + static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE @@ -119,15 +119,15 @@ class TransformUtil { * Return: * [std::vector] the data tensor in ME * */ - static std::vector ConvertGeTensors(const std::vector& ge_tensors, - const std::vector>& request_dims); + static std::vector ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE * Return: * [std::vector] the data tensor in ME * */ - static std::vector ConvertGeTensors(const std::vector& ge_tensors); + static std::vector ConvertGeTensors(const std::vector &ge_tensors); /* * Parameters: * ge_tensor: [GeTensor] the data tensor in GE @@ -136,15 +136,15 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr GenerateMeTensor(const GeTensorPtr& ge_tensor, const std::vector& me_dims, - const TypeId& me_type); + static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type); /* * Parameters: * type: [GeDataType] the ge tensor data type * Return: * 
[MeDataType] the me tensor data type * */ - static MeDataType ConvertGeDataType(const GeDataType& type); + static MeDataType ConvertGeDataType(const GeDataType &type); /* * Parameters: @@ -152,7 +152,7 @@ class TransformUtil { * Return: * [GeShape] the ge shape * */ - static GeShape ConvertMeShape(const std::vector& me_dims); + static GeShape ConvertMeShape(const std::vector &me_dims); /* * Parameters: @@ -160,7 +160,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape& ge_shape); + static std::vector ConvertGeShape(const GeShape &ge_shape); /* Function: * Convert GeShape to Me request shape, Support pattern: @@ -176,7 +176,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape& ge_shape, const std::vector& request_dims); + static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); /* * Parameters: @@ -185,7 +185,7 @@ class TransformUtil { * [string] value string * */ template ::value>::type> - static std::string PrintVector(const std::vector& vec) { + static std::string PrintVector(const std::vector &vec) { const int MAX_PRINT_NUM = 100; std::stringstream ss; ss << "{ "; @@ -222,7 +222,7 @@ class TransformUtil { * [shared_ptr] vector pointer * */ template ::value>::type> - static std::vector MakeVector(const uint8_t* const data, size_t size) { + static std::vector MakeVector(const uint8_t *const data, size_t size) { auto dest = std::vector(size / sizeof(T)); if (data == nullptr) { return dest; diff --git a/mindspore/ccsrc/utils/any.cc b/mindspore/ccsrc/utils/any.cc index 31ee1fd3021..3cb89f5dd7f 100644 --- a/mindspore/ccsrc/utils/any.cc +++ b/mindspore/ccsrc/utils/any.cc @@ -21,7 +21,7 @@ namespace mindspore { // only support (int, float, bool) as Literal -bool AnyIsLiteral(const Any& any) { +bool AnyIsLiteral(const Any &any) { static const std::type_index typeid_int = std::type_index(typeid(int)); static const std::type_index typeid_float = std::type_index(typeid(float)); static const std::type_index typeid_bool = std::type_index(typeid(bool)); @@ -30,12 +30,12 @@ bool AnyIsLiteral(const Any& any) { return typeid_int == typeid_any || typeid_float == typeid_any || typeid_bool == typeid_any; } -std::ostream& operator<<(std::ostream& os, const pybind11::object& obj) { +std::ostream &operator<<(std::ostream &os, const pybind11::object &obj) { os << "[py::object]"; return os; } -Any& Any::operator=(const Any& other) { +Any &Any::operator=(const Any &other) { if (m_ptr == other.m_ptr || &other == this) { return *this; } @@ -44,9 +44,9 @@ Any& Any::operator=(const Any& other) { return *this; } -bool Any::operator<(const Any& other) const { return this < &other; } +bool Any::operator<(const Any &other) const { return this < &other; } -Any& Any::operator=(Any&& other) { +Any &Any::operator=(Any &&other) { if (this != &other) { if (m_ptr == other.m_ptr || &other == this) { return *this; diff --git a/mindspore/ccsrc/utils/any.h b/mindspore/ccsrc/utils/any.h index ce691f1c122..b4edf602ac2 100644 --- a/mindspore/ccsrc/utils/any.h +++ b/mindspore/ccsrc/utils/any.h @@ -35,23 +35,23 @@ namespace mindspore { // usage:AnyPtr sp = std::make_shared(aname); template -std::string type(const T& t) { +std::string type(const T &t) { return demangle(typeid(t).name()); } -std::ostream& operator<<(std::ostream& os, const pybind11::object& obj); +std::ostream &operator<<(std::ostream &os, const pybind11::object &obj); class Any { public: // constructors Any() : 
m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {} - Any(const Any& other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} - Any(Any&& other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} + Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} + Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} - Any& operator=(Any&& other); + Any &operator=(Any &&other); // right reference constructor template ::type, Any>::value, T>::type> - Any(T&& t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT + Any(T &&t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT BasePtr new_val(new Derived::type>(std::forward(t))); std::swap(m_ptr, new_val); } @@ -67,7 +67,7 @@ class Any { return m_tpIndex == std::type_index(typeid(T)); } - const std::type_info& type() const { return m_ptr ? m_ptr->type() : typeid(void); } + const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); } std::size_t Hash() const { std::stringstream buffer; @@ -79,7 +79,7 @@ class Any { } template - bool Apply(const std::function& fn) { + bool Apply(const std::function &fn) { if (type() == typeid(T)) { T x = cast(); fn(x); @@ -96,23 +96,23 @@ class Any { } } - friend std::ostream& operator<<(std::ostream& os, const Any& any) { + friend std::ostream &operator<<(std::ostream &os, const Any &any) { os << any.GetString(); return os; } // type cast template - T& cast() const { + T &cast() const { if (!is() || !m_ptr) { // Use MS_LOGFATAL replace throw std::bad_cast() MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name(); } - auto ptr = static_cast*>(m_ptr.get()); + auto ptr = static_cast *>(m_ptr.get()); return ptr->m_value; } - bool operator==(const Any& other) const { + bool operator==(const Any &other) const { if (m_tpIndex != other.m_tpIndex) { return false; } @@ -125,11 +125,11 @@ class Any { return *m_ptr == *other.m_ptr; } - bool operator!=(const Any& other) const { return !(operator==(other)); } + bool operator!=(const Any &other) const { return !(operator==(other)); } - Any& operator=(const Any& other); + Any &operator=(const Any &other); - bool operator<(const Any& other) const; + bool operator<(const Any &other) const; std::string ToString() const { std::ostringstream buffer; @@ -154,26 +154,26 @@ class Any { // type base definition struct Base { - virtual const std::type_info& type() const = 0; + virtual const std::type_info &type() const = 0; virtual BasePtr clone() const = 0; virtual ~Base() = default; - virtual bool operator==(const Base& other) const = 0; + virtual bool operator==(const Base &other) const = 0; virtual std::string GetString() = 0; }; template struct Derived : public Base { template - explicit Derived(Args&&... args) : m_value(std::forward(args)...), serialize_cache_("") {} + explicit Derived(Args &&... 
args) : m_value(std::forward(args)...), serialize_cache_("") {} - bool operator==(const Base& other) const override { + bool operator==(const Base &other) const override { if (typeid(*this) != typeid(other)) { return false; } - return m_value == static_cast&>(other).m_value; + return m_value == static_cast &>(other).m_value; } - const std::type_info& type() const override { return typeid(T); } + const std::type_info &type() const override { return typeid(T); } BasePtr clone() const override { return BasePtr(new Derived(m_value)); } @@ -204,14 +204,14 @@ class Any { using AnyPtr = std::shared_ptr; struct AnyHash { - std::size_t operator()(const Any& c) const { return c.Hash(); } + std::size_t operator()(const Any &c) const { return c.Hash(); } }; struct AnyLess { - bool operator()(const Any& a, const Any& b) const { return a.Hash() < b.Hash(); } + bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); } }; -bool AnyIsLiteral(const Any& any); +bool AnyIsLiteral(const Any &any); } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.cc b/mindspore/ccsrc/utils/base_ref.cc index e50f0003b8f..aa38c8a6a09 100644 --- a/mindspore/ccsrc/utils/base_ref.cc +++ b/mindspore/ccsrc/utils/base_ref.cc @@ -17,17 +17,17 @@ #include "utils/base_ref.h" namespace mindspore { -iterator ConstIteratorCast(std::vector* v, const const_iterator iter) { +iterator ConstIteratorCast(std::vector *v, const const_iterator iter) { return std::next(v->begin(), std::distance(v->cbegin(), iter)); } -BaseRef::BaseRef(const BaseRef& other) : Base(other), m_ptr(other.m_ptr) { +BaseRef::BaseRef(const BaseRef &other) : Base(other), m_ptr(other.m_ptr) { if (!m_ptr) { m_ptr = other.copy(); } } -bool BaseRef::operator==(const BaseRef& other) const { +bool BaseRef::operator==(const BaseRef &other) const { if (m_ptr == other.m_ptr) { return true; } @@ -55,7 +55,7 @@ bool BaseRef::operator==(const BaseRef& other) const { } // left reference -BaseRef& BaseRef::operator=(const BaseRef& other) { +BaseRef &BaseRef::operator=(const BaseRef &other) { if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { return *this; } @@ -64,7 +64,7 @@ BaseRef& BaseRef::operator=(const BaseRef& other) { } // right reference -BaseRef& BaseRef::operator=(BaseRef&& other) { +BaseRef &BaseRef::operator=(BaseRef &&other) { if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { return *this; } @@ -88,7 +88,7 @@ uint32_t BaseRef::type() const { } // left reference -SetRef& SetRef::operator=(const SetRef& other) { +SetRef &SetRef::operator=(const SetRef &other) { if (elements_ == other.elements_ || this == &other) { return *this; } @@ -100,7 +100,7 @@ std::string SetRef::ToString() const { std::ostringstream buffer; bool begin = true; buffer << "set["; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -113,7 +113,7 @@ std::string SetRef::ToString() const { } // left reference -VectorRef& VectorRef::operator=(const VectorRef& other) { +VectorRef &VectorRef::operator=(const VectorRef &other) { if (elements_ == other.elements_ || this == &other) { return *this; } @@ -125,7 +125,7 @@ std::string VectorRef::ToString() const { std::ostringstream buffer; bool begin = true; buffer << "vector["; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -137,14 +137,14 @@ std::string VectorRef::ToString() const { return buffer.str(); } -bool VectorRef::operator==(const BaseRef& other) const { +bool 
VectorRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool VectorRef::operator==(const VectorRef& other) const { +bool VectorRef::operator==(const VectorRef &other) const { if (elements_.size() != other.elements_.size()) { return false; } @@ -156,14 +156,14 @@ bool VectorRef::operator==(const VectorRef& other) const { return true; } -bool SetRef::operator==(const BaseRef& other) const { +bool SetRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool SetRef::operator==(const SetRef& other) const { +bool SetRef::operator==(const SetRef &other) const { if (elements_.size() != other.elements_.size()) { return false; } @@ -177,21 +177,21 @@ bool SetRef::operator==(const SetRef& other) const { return true; } -bool RunFunctionRef::operator==(const BaseRef& other) const { +bool RunFunctionRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool RunFunctionRef::operator==(const RunFunctionRef& other) const { return func_ == other.func_; } +bool RunFunctionRef::operator==(const RunFunctionRef &other) const { return func_ == other.func_; } -bool PyObjectRef::operator==(const BaseRef& other) const { +bool PyObjectRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool PyObjectRef::operator==(const PyObjectRef& other) const { return object_ == other.object_; } +bool PyObjectRef::operator==(const PyObjectRef &other) const { return object_ == other.object_; } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.h b/mindspore/ccsrc/utils/base_ref.h index ed00d8280c2..6e7911d0d92 100644 --- a/mindspore/ccsrc/utils/base_ref.h +++ b/mindspore/ccsrc/utils/base_ref.h @@ -40,7 +40,7 @@ using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; using const_reverse_iterator = std::vector::const_reverse_iterator; -using RunFunc = std::function; +using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; template @@ -54,9 +54,9 @@ using is_value = std::is_base_of>; template using is_base_ref = std::is_base_of>; -iterator ConstIteratorCast(std::vector* v, const_iterator iter); +iterator ConstIteratorCast(std::vector *v, const_iterator iter); -inline std::shared_ptr MakeNode(const std::vector& elements) { +inline std::shared_ptr MakeNode(const std::vector &elements) { return std::make_shared(elements); } @@ -68,34 +68,34 @@ inline std::shared_ptr MakeNode(std::initializer_list elemen template >::value && is_base::value, int>::type = 0> -inline BasePtr MakeNode(const T& v) { +inline BasePtr MakeNode(const T &v) { return v; } template >::value && !is_base_ref::value, int>::type = 0> -inline BasePtr MakeNode(const T& v) { +inline BasePtr MakeNode(const T &v) { return MakeValue(v); } -inline std::shared_ptr MakeNode(const VectorRef& a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const AnfNodePtrList& a) { +inline std::shared_ptr MakeNode(const VectorRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const AnfNodePtrList &a) { std::vector ret; - (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr& v) { return v; }); + (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; }); return std::make_shared(ret); } -inline 
std::shared_ptr MakeNode(const SetRef& a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const RunFuncPtr& a) { return std::make_shared(a); } -inline std::shared_ptr MakeNode(const py::object& a) { return std::make_shared(a); } -inline std::shared_ptr MakeNode(const py::tuple& a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const SetRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const RunFuncPtr &a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const py::object &a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const py::tuple &a) { return std::make_shared(a); } class BaseRef : public Base { public: BaseRef() : m_ptr(nullptr) {} - BaseRef(const BaseRef& other); + BaseRef(const BaseRef &other); virtual std::shared_ptr copy() const { return m_ptr; } - BaseRef(BaseRef&& other) : Base(other) { + BaseRef(BaseRef &&other) : Base(other) { m_ptr = other.m_ptr; other.m_ptr = nullptr; } @@ -103,7 +103,7 @@ class BaseRef : public Base { // right reference constructor template ::type, BaseRef>::value, T>::type> - BaseRef(T&& t) { // NOLINT + BaseRef(T &&t) { // NOLINT m_ptr = MakeNode(t); } @@ -111,14 +111,14 @@ class BaseRef : public Base { MS_DECLARE_PARENT(BaseRef, Base) - bool operator!=(const BaseRef& other) const { return !(operator==(other)); } + bool operator!=(const BaseRef &other) const { return !(operator==(other)); } - virtual bool operator==(const BaseRef& other) const; + virtual bool operator==(const BaseRef &other) const; // left reference - virtual BaseRef& operator=(const BaseRef& other); + virtual BaseRef &operator=(const BaseRef &other); // right reference - virtual BaseRef& operator=(BaseRef&& other); + virtual BaseRef &operator=(BaseRef &&other); std::size_t hash() const override { if (m_ptr == nullptr) { @@ -139,18 +139,18 @@ class BaseRef : public Base { using BaseRefPtr = std::shared_ptr; struct BaseRefHash { - std::size_t operator()(const BaseRef& c) const { return c.hash(); } + std::size_t operator()(const BaseRef &c) const { return c.hash(); } }; struct BaseRefLess { - bool operator()(const BaseRef& a, const BaseRef& b) const { return a.hash() < b.hash(); } + bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); } }; namespace utils { // judge isa relation // examples: isa(handle), isa(handle) template ::value && !is_base_ref::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (!handle.m_ptr) { return false; } @@ -160,7 +160,7 @@ bool isa(const BaseRef& handle) { // noderef isa ptr isa(x) or isa() template ::value, typename T::element_type>::type, typename std::enable_if::value || is_base_ref::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return typeid(handle.m_ptr) == typeid(T); } @@ -175,7 +175,7 @@ bool isa(const BaseRef& handle) { // isa(handle) template ::type::element_type> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return false; } @@ -184,7 +184,7 @@ bool isa(const BaseRef& handle) { // isa(handle), judge reference or ptr template ::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { static const uint32_t tid = Base::GetTypeId(typeid(T).name()); return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa()); } @@ -192,7 +192,7 @@ bool isa(const BaseRef& handle) { // valueref -> C++ type 
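The base_ref.h hunks here touch the utils::isa / utils::cast helpers that type-check and unwrap a BaseRef handle. As a rough, self-contained illustration of that isa/cast idiom (the Node and IntNode names and the shared_ptr-based layout below are invented for the sketch and are not the BaseRef internals), again with right-aligned references:

#include <iostream>
#include <memory>

struct Node { virtual ~Node() = default; };
struct IntNode : Node { explicit IntNode(int v) : value(v) {} int value; };

// isa<T>: does the erased handle actually hold a T?
template <typename T>
bool isa(const std::shared_ptr<Node> &handle) {
  return std::dynamic_pointer_cast<T>(handle) != nullptr;
}

// cast<T>: unwrap after an isa<T> check (returns nullptr otherwise).
template <typename T>
std::shared_ptr<T> cast(const std::shared_ptr<Node> &handle) {
  return std::dynamic_pointer_cast<T>(handle);
}

int main() {
  std::shared_ptr<Node> h = std::make_shared<IntNode>(42);
  if (isa<IntNode>(h)) {
    std::cout << cast<IntNode>(h)->value << "\n";  // prints 42
  }
  return 0;
}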
// cast(handle) template ::value && !is_shared_ptr::value, int>::type = 0> -T cast(const BaseRef& handle) { +T cast(const BaseRef &handle) { T ret = GetValue(std::static_pointer_cast(handle.m_ptr)); return std::move(ret); } @@ -200,12 +200,12 @@ T cast(const BaseRef& handle) { // valueref -> valueref type // cast(handle) template ::value, int>::type = 0> -const T& cast(const BaseRef& handle) { +const T &cast(const BaseRef &handle) { if (handle.m_ptr) { - return static_cast(*handle.m_ptr); + return static_cast(*handle.m_ptr); } - return std::move(static_cast(handle)); + return std::move(static_cast(handle)); } // valueref -> nodeptr type @@ -213,7 +213,7 @@ const T& cast(const BaseRef& handle) { template ::value, typename T::element_type>::type, typename std::enable_if::value && std::is_base_of::value, int>::type = 0> -T cast(const BaseRef& handle) { +T cast(const BaseRef &handle) { if (!handle.m_ptr) { MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null"; } @@ -229,11 +229,11 @@ T cast(const BaseRef& handle) { class VectorRef : public BaseRef { public: VectorRef() {} - explicit VectorRef(const std::vector& elements) : elements_(elements) {} - VectorRef(const const_iterator& begin, const const_iterator& end) : elements_(begin, end) {} + explicit VectorRef(const std::vector &elements) : elements_(elements) {} + VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} // left reference - virtual VectorRef& operator=(const VectorRef& other); + virtual VectorRef &operator=(const VectorRef &other); ~VectorRef() override = default; @@ -244,7 +244,7 @@ class VectorRef : public BaseRef { std::size_t size() const { return elements_.size(); } MS_DECLARE_PARENT(VectorRef, BaseRef) - const BaseRef& operator[](const std::size_t& dim) const { + const BaseRef &operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "Out of the size of the tuple."; } @@ -253,17 +253,17 @@ class VectorRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override; - std::vector& elements() { return elements_; } + std::vector &elements() { return elements_; } void clear() { elements_.clear(); } - bool operator==(const BaseRef& other) const override; - bool operator==(const VectorRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const VectorRef &other) const; - void push_back(const BaseRef& value) { elements_.push_back(value); } - void push_back(BaseRef&& value) { elements_.push_back(value); } + void push_back(const BaseRef &value) { elements_.push_back(value); } + void push_back(BaseRef &&value) { elements_.push_back(value); } - void emplace_back(const BaseRef& value) { elements_.emplace_back(value); } - void emplace_back(BaseRef&& value) { elements_.emplace_back(value); } + void emplace_back(const BaseRef &value) { elements_.emplace_back(value); } + void emplace_back(BaseRef &&value) { elements_.emplace_back(value); } template void insert(const iterator pos, const InputIt first, const InputIt last) { @@ -308,21 +308,21 @@ using set_iterator = std::set::iterator; using const_set_iterator = std::set::const_iterator; struct VectorRefHash { - std::size_t operator()(const VectorRef& c) const { return c.hash(); } + std::size_t operator()(const VectorRef &c) const { return c.hash(); } }; class SetRef : public BaseRef { public: SetRef() {} - explicit SetRef(const std::set& elements) : elements_(elements) {} + explicit SetRef(const std::set &elements) 
: elements_(elements) {} SetRef(const std::initializer_list elements) : elements_(elements.begin(), elements.end()) {} - SetRef(const const_set_iterator& begin, const const_set_iterator& end) : elements_(begin, end) {} + SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {} // left reference - virtual SetRef& operator=(const SetRef& other); + virtual SetRef &operator=(const SetRef &other); - bool operator==(const BaseRef& other) const override; - bool operator==(const SetRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const SetRef &other) const; ~SetRef() override = default; @@ -335,10 +335,10 @@ class SetRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override; - std::set& elements() { return elements_; } + std::set &elements() { return elements_; } void clear() { elements_.clear(); } - void insert(const BaseRef& elem) { (void)elements_.insert(elem); } + void insert(const BaseRef &elem) { (void)elements_.insert(elem); } const_set_iterator begin() const { return elements_.begin(); } const_set_iterator end() const { return elements_.end(); } @@ -348,8 +348,8 @@ class SetRef : public BaseRef { (void)elements_.insert(first, last); } - std::size_t count(const BaseRef& elem) const { return elements_.count(elem); } - const_set_iterator find(const BaseRef& elem) const { return elements_.find(elem); } + std::size_t count(const BaseRef &elem) const { return elements_.count(elem); } + const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); } std::set elements_; }; @@ -358,8 +358,8 @@ using SetRefPtr = std::shared_ptr; class PyObjectRef : public BaseRef { public: - explicit PyObjectRef(const py::object& py_object) : object_(py_object) {} - explicit PyObjectRef(const py::tuple& tuple_obj) : object_(tuple_obj) {} + explicit PyObjectRef(const py::object &py_object) : object_(py_object) {} + explicit PyObjectRef(const py::tuple &tuple_obj) : object_(tuple_obj) {} ~PyObjectRef() override = default; @@ -368,8 +368,8 @@ class PyObjectRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override { return py::str(object_); } - bool operator==(const BaseRef& other) const override; - bool operator==(const PyObjectRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const PyObjectRef &other) const; py::object object_; }; @@ -377,15 +377,15 @@ class PyObjectRef : public BaseRef { class RunFunctionRef : public BaseRef { public: RunFunctionRef() {} - explicit RunFunctionRef(const RunFuncPtr& ref_func) : func_(ref_func) {} + explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {} ~RunFunctionRef() override = default; MS_DECLARE_PARENT(RunFunctionRef, BaseRef) uint32_t type() const override { return tid(); } std::string ToString() const override { return std::string("RunFunctionRef"); } - bool operator==(const BaseRef& other) const override; - bool operator==(const RunFunctionRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const RunFunctionRef &other) const; RunFuncPtr func_; }; diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc index 03c6322afe4..06bf1c73ab7 100644 --- a/mindspore/ccsrc/utils/callbacks.cc +++ b/mindspore/ccsrc/utils/callbacks.cc @@ -37,14 +37,14 @@ const int ONE_SHAPE = 1; // Cache the summary callback data from ME session // Remove the GE module on new 
architecture // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...] -uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; py::list summary_list = py::list(); MS_LOG(INFO) << "The Summary save callback function for graph " << graph_id << ", Param list size = " << params_list.size() << "."; - for (auto& item : params_list) { + for (auto &item : params_list) { std::string tag_name = item.first; auto tensor_ptr = item.second; if (tensor_ptr == nullptr) { diff --git a/mindspore/ccsrc/utils/callbacks.h b/mindspore/ccsrc/utils/callbacks.h index a1e4e75d5b6..9f46df0414c 100644 --- a/mindspore/ccsrc/utils/callbacks.h +++ b/mindspore/ccsrc/utils/callbacks.h @@ -39,9 +39,9 @@ extern const std::string kPythonCheckpointFuncName; const int kCallbackOk = 0; const int kCallbackFalied = 1; -bool GetParameterShape(const FuncGraphPtr& anf_graph, const std::string& param_name, - const std::shared_ptr>& shape); -uint32_t SummarySaveCallback(uint32_t, const std::map&); +bool GetParameterShape(const FuncGraphPtr &anf_graph, const std::string ¶m_name, + const std::shared_ptr> &shape); +uint32_t SummarySaveCallback(uint32_t, const std::map &); } // namespace callbacks } // namespace mindspore diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 36bbcbf297b..b4c9fda6340 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -35,15 +35,15 @@ const int ONE_SHAPE = 1; using mindspore::transform::Status; using mindspore::transform::TransformUtil; -bool GetParameterShape(const FuncGraphPtr& graph, const std::string& param_name, - const std::shared_ptr>& shape) { +bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, + const std::shared_ptr> &shape) { if (graph == nullptr) { MS_LOG(ERROR) << "Graph is null, can not get graph parameter"; return false; } auto parameter_nodes = graph->parameters(); - for (auto& node : parameter_nodes) { + for (auto &node : parameter_nodes) { ParameterPtr param_node = std::static_pointer_cast(node); if (param_node == nullptr) { MS_LOG(ERROR) << "Parameter node is null, can not get graph parameter"; @@ -65,8 +65,8 @@ bool GetParameterShape(const FuncGraphPtr& graph, const std::string& param_name, return false; } -static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string& parameter_name, - const std::shared_ptr& ge_tensor_ptr) { +static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string ¶meter_name, + const std::shared_ptr &ge_tensor_ptr) { FuncGraphPtr anf_graph = transform::DfGraphManager::GetInstance().GetAnfGraph(graph_id); if (anf_graph == nullptr) { MS_LOG(ERROR) << "Get anf graph failed during callback"; @@ -82,13 +82,13 @@ static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string& pa return TransformUtil::ConvertGeTensor(ge_tensor_ptr, *parameter_shape_ptr); } -uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; MS_LOG(DEBUG) << "Start the checkpoint save callback function in checkpoint save process."; py::list parameter_list = py::list(); - for (auto& item : params_list) { + for (auto &item : params_list) { 
std::string name = item.first; std::shared_ptr ge_tensor_ptr = std::make_shared(item.second); TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr); @@ -112,7 +112,7 @@ uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map& ge_tensor_ptr) { +static TensorPtr GetMeTensorForSummary(const std::string &name, const std::shared_ptr &ge_tensor_ptr) { // confirm the type by name // Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor] if (name.empty()) { @@ -149,14 +149,14 @@ static TensorPtr GetMeTensorForSummary(const std::string& name, const std::share // Cache the summary callback data // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...] -uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map &params_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; MS_LOG(DEBUG) << "Start the summary save callback function for graph " << graph_id << "."; py::list summary_list = py::list(); MS_LOG(DEBUG) << "Param list size = " << params_list.size(); - for (auto& item : params_list) { + for (auto &item : params_list) { std::string tag_name = item.first; std::shared_ptr ge_tensor_ptr = std::make_shared(item.second); TensorPtr tensor_ptr = GetMeTensorForSummary(tag_name, ge_tensor_ptr); diff --git a/mindspore/ccsrc/utils/callbacks_ge.h b/mindspore/ccsrc/utils/callbacks_ge.h index 750ec746665..08f5bb59dbf 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.h +++ b/mindspore/ccsrc/utils/callbacks_ge.h @@ -29,8 +29,8 @@ namespace callbacks { using mindspore::tensor::TensorPtr; -uint32_t CheckpointSaveCallback(uint32_t, const std::map&); -uint32_t SummarySaveCallback(uint32_t, const std::map&); +uint32_t CheckpointSaveCallback(uint32_t, const std::map &); +uint32_t SummarySaveCallback(uint32_t, const std::map &); } // namespace callbacks } // namespace mindspore diff --git a/mindspore/ccsrc/utils/config_manager.cc b/mindspore/ccsrc/utils/config_manager.cc index 6d66b37436c..7dc559b20e8 100644 --- a/mindspore/ccsrc/utils/config_manager.cc +++ b/mindspore/ccsrc/utils/config_manager.cc @@ -22,12 +22,12 @@ namespace mindspore { -ConfigManager& ConfigManager::GetInstance() noexcept { +ConfigManager &ConfigManager::GetInstance() noexcept { static ConfigManager instance; return instance; } -void ConfigManager::SetDatasetModeConfig(const std::string& mode) { +void ConfigManager::SetDatasetModeConfig(const std::string &mode) { static const std::map mode_map = {{"normal", DS_NORMAL_MODE}, {"sink", DS_SINK_MODE}}; if (mode_map.find(mode) == mode_map.end()) { MS_LOG(ERROR) << "Invalid dataset mode:" << mode; diff --git a/mindspore/ccsrc/utils/config_manager.h b/mindspore/ccsrc/utils/config_manager.h index db7d7d0c14d..635f24792aa 100644 --- a/mindspore/ccsrc/utils/config_manager.h +++ b/mindspore/ccsrc/utils/config_manager.h @@ -37,8 +37,8 @@ enum DatasetMode { DS_NORMAL_MODE = 0, DS_SINK_MODE }; class DatasetGraphParam { public: - DatasetGraphParam(const std::string& name, int64_t size, int64_t batch_size, const std::vector& ge_types, - const std::vector>& shapes, const std::vector& input_indexes) + DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector &ge_types, + const std::vector> &shapes, const std::vector &input_indexes) : queue_name_(name), loop_size_(size), batch_size_(batch_size), @@ -72,15 +72,15 @@ class DatasetGraphParam { class ConfigManager { public: - ConfigManager(const
ConfigManager&) = delete; - ConfigManager& operator=(const ConfigManager&) = delete; - static ConfigManager& GetInstance() noexcept; + ConfigManager(const ConfigManager &) = delete; + ConfigManager &operator=(const ConfigManager &) = delete; + static ConfigManager &GetInstance() noexcept; ParallelStrategy parallel_strategy() const { return parallel_strategy_; } void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; } - const std::map& ge_initialize_options() const { return ge_initialize_options_; } - void set_ge_initialize_options(const std::map& options) { + const std::map &ge_initialize_options() const { return ge_initialize_options_; } + void set_ge_initialize_options(const std::map &options) { ge_initialize_options_ = options; } @@ -90,12 +90,12 @@ class ConfigManager { void set_iter_num(const int64_t num) { iter_num_ = num; } std::string dataset_phase() const { return dataset_phase_; } - void set_dataset_phase(const std::string& phase) { dataset_phase_ = phase; } + void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; } DatasetGraphParam dataset_param() const { return dataset_param_; } - void set_dataset_param(const DatasetGraphParam& param) { dataset_param_ = param; } + void set_dataset_param(const DatasetGraphParam &param) { dataset_param_ = param; } - static void SetDatasetModeConfig(const std::string& mode); + static void SetDatasetModeConfig(const std::string &mode); void ResetConfig() noexcept; diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index bee5875f603..0a2f065140e 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -45,7 +45,7 @@ std::map MsContext::policy_map_ = {{"ge", kMsBacke {"ge_only", kMsBackendGeOnly}, {"vm_prior", kMsBackendVmPrior}}; -MsContext::MsContext(const std::string& policy, const std::string& target) { +MsContext::MsContext(const std::string &policy, const std::string &target) { save_graphs_flag_ = false; save_graphs_path_ = "."; save_ms_model_flag_ = false; @@ -97,7 +97,7 @@ std::shared_ptr MsContext::GetInstance() { return inst_context_; } -bool MsContext::set_backend_policy(const std::string& policy) { +bool MsContext::set_backend_policy(const std::string &policy) { if (policy_map_.find(policy) == policy_map_.end()) { MS_LOG(ERROR) << "invalid backend policy name: " << policy; return false; @@ -110,7 +110,7 @@ bool MsContext::set_backend_policy(const std::string& policy) { std::string MsContext::backend_policy() const { auto res = std::find_if( policy_map_.begin(), policy_map_.end(), - [&, this](const std::pair& item) { return item.second == backend_policy_; }); + [&, this](const std::pair &item) { return item.second == backend_policy_; }); if (res != policy_map_.end()) { return res->first; } @@ -124,7 +124,7 @@ void MsContext::set_execution_mode(int execution_mode) { execution_mode_ = execution_mode; } -bool MsContext::set_device_target(const std::string& target) { +bool MsContext::set_device_target(const std::string &target) { if (kTargetSet.find(target) == kTargetSet.end()) { MS_LOG(ERROR) << "invalid device target name: " << target; return false; @@ -218,7 +218,7 @@ bool MsContext::CloseTsd(bool force) { MS_LOG(INFO) << "join tdt host receive process"; tdt_print_.join(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); } #endif @@ -241,7 +241,7 @@ bool MsContext::OpenTsd() { return true; } bool
MsContext::CloseTsd(bool) { return true; } #endif -void MsContext::SetHcclOptions(std::map* ge_options) const { +void MsContext::SetHcclOptions(std::map *ge_options) const { auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); auto env_rank_id = common::GetEnv("RANK_ID"); auto env_device_id = std::to_string(device_id_); @@ -274,7 +274,7 @@ void MsContext::SetHcclOptions(std::map* ge_options) c } } -void MsContext::GetGeOptions(std::map* ge_options) const { +void MsContext::GetGeOptions(std::map *ge_options) const { #ifdef ENABLE_GE (*ge_options)["device_id"] = "0"; (*ge_options)["ge.exec.enableDump"] = std::to_string(enable_dump_); @@ -365,7 +365,7 @@ void MsContext::GetGeOptions(std::map* ge_options) con #endif } -void MsContext::SetDisableReuseMemoryFlag(std::map* ge_options) const { +void MsContext::SetDisableReuseMemoryFlag(std::map *ge_options) const { auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY"); if (!env_disable_reuse_memory.empty()) { (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory; @@ -412,7 +412,7 @@ bool MsContext::FinalizeGe(bool force) { try { DfGraphManager::GetInstance().DeleteGraphRunner(); DfGraphManager::GetInstance().DeleteGeSession(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what(); } catch (...) { std::string exName(abi::__cxa_current_exception_type()->name()); diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 06704ff9c6e..1d84061a8a1 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -48,13 +48,13 @@ const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, class MsContext { public: ~MsContext() = default; - MsContext(const MsContext&) = delete; - MsContext& operator=(const MsContext&) = delete; + MsContext(const MsContext &) = delete; + MsContext &operator=(const MsContext &) = delete; static std::shared_ptr GetInstance(); std::string backend_policy() const; - bool set_backend_policy(const std::string& policy); + bool set_backend_policy(const std::string &policy); int execution_mode() const { return execution_mode_; } void set_execution_mode(int execution_mode); @@ -69,7 +69,7 @@ class MsContext { bool precompile_only() const { return precompile_only_; } std::string device_target() const { return device_target_; } - bool set_device_target(const std::string& target); + bool set_device_target(const std::string &target); uint32_t device_id() const { return device_id_; } bool set_device_id(uint32_t device_id); @@ -78,7 +78,7 @@ class MsContext { void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } std::string save_graphs_path() const { return save_graphs_path_; } - void set_save_graphs_path(const std::string& save_paths) { save_graphs_path_ = save_paths; } + void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } bool OpenTsd(); bool CloseTsd(bool force = false); @@ -101,7 +101,7 @@ class MsContext { void set_save_ms_model_flag(bool save_ms_model_flag) { save_ms_model_flag_ = save_ms_model_flag; } std::string save_ms_model_path() const { return save_ms_model_path_; } - void set_save_ms_model_path(const std::string& save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } + void set_save_ms_model_path(const std::string &save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } 
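// Standalone sketch (not from the MindSpore sources) of the pattern used by
// MsContext::SetDisableReuseMemoryFlag / SetHcclOptions above: read an environment
// variable and, only when it is set, copy it into the GE options map passed by pointer.
// GetEnvOrEmpty is a local stand-in for common::GetEnv; the env name and option key
// are the ones visible in the patch.
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

static std::string GetEnvOrEmpty(const std::string &name) {
  const char *value = std::getenv(name.c_str());
  return value == nullptr ? std::string() : std::string(value);
}

static void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) {
  auto env_disable_reuse_memory = GetEnvOrEmpty("DISABLE_REUSE_MEMORY");
  if (!env_disable_reuse_memory.empty()) {
    (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
  }
}

int main() {
  std::map<std::string, std::string> ge_options;
  SetDisableReuseMemoryFlag(&ge_options);
  for (const auto &item : ge_options) {
    std::cout << item.first << "=" << item.second << "\n";
  }
  return 0;
}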
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } bool enable_gpu_summary() const { return enable_gpu_summary_; } @@ -117,7 +117,7 @@ class MsContext { void set_enable_dump(bool flag) { enable_dump_ = flag; } bool enable_dump() const { return enable_dump_; } - void set_save_dump_path(const std::string& path) { save_dump_path_ = path; } + void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } std::string save_dump_path() const { return save_dump_path_; } bool IsTsdOpened() const { return tsd_ref_ > 0; } @@ -128,19 +128,19 @@ class MsContext { void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } - void set_graph_memory_max_size(const std::string& graph_memory_max_size) { + void set_graph_memory_max_size(const std::string &graph_memory_max_size) { graph_memory_max_size_ = graph_memory_max_size; } - void set_variable_memory_max_size(const std::string& variable_memory_max_size) { + void set_variable_memory_max_size(const std::string &variable_memory_max_size) { variable_memory_max_size_ = variable_memory_max_size; } private: - MsContext(const std::string& backend_policy, const std::string& target); - void GetGeOptions(std::map* ge_options) const; - void SetDisableReuseMemoryFlag(std::map* ge_options) const; - void SetHcclOptions(std::map* ge_options) const; + MsContext(const std::string &backend_policy, const std::string &target); + void GetGeOptions(std::map *ge_options) const; + void SetDisableReuseMemoryFlag(std::map *ge_options) const; + void SetHcclOptions(std::map *ge_options) const; static std::shared_ptr inst_context_; static std::map policy_map_; diff --git a/mindspore/ccsrc/utils/counter.h b/mindspore/ccsrc/utils/counter.h index 891f9c7a35a..ead0ad84f20 100644 --- a/mindspore/ccsrc/utils/counter.h +++ b/mindspore/ccsrc/utils/counter.h @@ -29,17 +29,17 @@ class Counter { Counter() = default; ~Counter() = default; - Counter(const Counter& other) { data = other.data; } - Counter& operator=(const Counter& other) { + Counter(const Counter &other) { data = other.data; } + Counter &operator=(const Counter &other) { if (this != &other) { data = other.data; } return *this; } - int& operator[](const T& t) { return data[t]; } + int &operator[](const T &t) { return data[t]; } - counter_type operator-(const counter_type& other) { + counter_type operator-(const counter_type &other) { counter_type new_counter; for (auto iter = begin(); iter != end(); ++iter) { auto key = iter->first; @@ -58,7 +58,7 @@ class Counter { return new_counter; } - counter_type operator+(const counter_type& other) { + counter_type operator+(const counter_type &other) { counter_type new_counter; for (auto iter = begin(); iter != end(); ++iter) { auto key = iter->first; @@ -84,7 +84,7 @@ class Counter { std::size_t size() const { return data.size(); } - bool contains(const T& t) const { return data.find(t) != data.end(); } + bool contains(const T &t) const { return data.find(t) != data.end(); } typename OrderedMap::iterator begin() { return data.begin(); } diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc index 55ef8dc3d5a..08016225495 100644 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ b/mindspore/ccsrc/utils/graph_utils.cc @@ -39,10 +39,10 @@ using SymbolicKeyTypePtr = std::shared_ptr; namespace { class DeepFirstSearcher : public AnfVisitor { public: - explicit 
DeepFirstSearcher(const IncludeFunc& include) : include_(include) {} + explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {} ~DeepFirstSearcher() override = default; - std::vector Search(const AnfNodePtr& root) { + std::vector Search(const AnfNodePtr &root) { if (root == nullptr) { return res_; } @@ -50,7 +50,7 @@ class DeepFirstSearcher : public AnfVisitor { return res_; } - void Visit(const AnfNodePtr& node) override { + void Visit(const AnfNodePtr &node) override { MS_EXCEPTION_IF_NULL(node); if (seen_.count(node) != 0) { return; @@ -77,10 +77,10 @@ class DeepFirstSearcher : public AnfVisitor { class DeepScopedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepScopedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepScopedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { + void Visit(const CNodePtr &cnode) override { if (cnode->func_graph() == nullptr) { return; } @@ -90,13 +90,13 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { DeepFirstSearcher::Visit(ret); } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr& vnode) override { + void Visit(const ValueNodePtr &vnode) override { if (!IsValueNode(vnode)) { return; } @@ -108,7 +108,7 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { } } - void Visit(const ParameterPtr& param) override { + void Visit(const ParameterPtr &param) override { if (param->func_graph() == nullptr) { return; } @@ -122,17 +122,17 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { class DeepUsedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepUsedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepUsedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { - auto& inputs = cnode->inputs(); + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr& vnode) override { + void Visit(const ValueNodePtr &vnode) override { if (!IsValueNode(vnode)) { return; } @@ -147,33 +147,33 @@ class DeepUsedGraphSearcher : public DeepFirstSearcher { class DeepLinkedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepLinkedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepLinkedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { - auto& inputs = cnode->inputs(); + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr&) override {} + void Visit(const ValueNodePtr &) override {} }; } // namespace -std::vector DeepScopedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepScopedGraphSearcher(include).Search(root); } -std::vector DeepUsedGraphSearch(const AnfNodePtr&
root, const IncludeFunc& include) { +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepUsedGraphSearcher(include).Search(root); } -std::vector DeepLinkedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepLinkedGraphSearcher(include).Search(root); } -std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ, const IncludeFunc& include) { +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { std::unordered_set done; std::list todo(1, root); std::unordered_map rank; @@ -222,7 +222,7 @@ std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ, c return res; } -std::vector SuccDeeper(const AnfNodePtr& node) { +std::vector SuccDeeper(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; @@ -237,7 +237,7 @@ std::vector SuccDeeper(const AnfNodePtr& node) { return vecs; } else if (node->func_graph() != nullptr) { if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } auto graph = node->func_graph(); @@ -250,7 +250,7 @@ std::vector SuccDeeper(const AnfNodePtr& node) { return vecs; } -std::vector SuccDeeperSimple(const AnfNodePtr& node) { +std::vector SuccDeeperSimple(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; @@ -265,39 +265,39 @@ std::vector SuccDeeperSimple(const AnfNodePtr& node) { return vecs; } else { if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } return vecs; } } -std::vector SuccIncoming(const AnfNodePtr& node) { +std::vector SuccIncoming(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; } if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } return vecs; } -std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& node) { +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; } if (node->isa()) { auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); // Check if free variables used. 
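// Standalone, simplified analogue of the TopoSort/SuccIncoming pair above (not MindSpore
// code; the real TopoSort is iterative): nodes are visited depth-first through a
// user-supplied successor function and emitted only after all of their inputs.
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <vector>

using Succ = std::function<std::vector<int>(int)>;

static void Visit(int node, const Succ &succ, std::set<int> *seen, std::vector<int> *out) {
  if (!seen->insert(node).second) {
    return;  // already visited
  }
  for (int next : succ(node)) {
    Visit(next, succ, seen, out);
  }
  out->push_back(node);  // emitted after its inputs, i.e. topological order
}

int main() {
  std::map<int, std::vector<int>> inputs = {{3, {1, 2}}, {2, {1}}, {1, {}}};
  std::set<int> seen;
  std::vector<int> order;
  Visit(3, [&inputs](int n) { return inputs[n]; }, &seen, &order);
  for (int n : order) std::cout << n << " ";  // prints: 1 2 3
  std::cout << "\n";
  return 0;
}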
- for (const auto& input : inputs) { + for (const auto &input : inputs) { auto input_fg = GetValueNode(input); if (input_fg) { - for (auto& fv : input_fg->free_variables_nodes()) { + for (auto &fv : input_fg->free_variables_nodes()) { if (fv->func_graph() == fg && fg->nodes().contains(fv)) { vecs.push_back(fv); } @@ -309,9 +309,9 @@ std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& return vecs; } -IncludeType AlwaysInclude(const AnfNodePtr&) { return FOLLOW; } +IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } -IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node) { +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { if (node->func_graph() == fg) { return FOLLOW; } else { @@ -319,12 +319,12 @@ IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node) { } } -FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search, const IncludeFunc& include) { +FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { MS_EXCEPTION_IF_NULL(fg); Acquire(fg); auto vec = search(fg->get_return(), include); - for (auto& node : vec) { + for (auto &node : vec) { MS_EXCEPTION_IF_NULL(node); Acquire(node); if (node->func_graph() != nullptr) { @@ -333,7 +333,7 @@ FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search, } } -std::set FuncGraphIndex::GetFuncGraphs(const std::string& key) { +std::set FuncGraphIndex::GetFuncGraphs(const std::string &key) { std::set func_graphs; if (index_func_graph_.find(key) != index_func_graph_.end()) { func_graphs = index_func_graph_[key]; @@ -341,7 +341,7 @@ std::set FuncGraphIndex::GetFuncGraphs(const std::string& key) { return func_graphs; } -std::set FuncGraphIndex::GetNodes(const std::string& key) { +std::set FuncGraphIndex::GetNodes(const std::string &key) { if (index_node_.find(key) != index_node_.end()) { return index_node_[key]; } @@ -349,7 +349,7 @@ std::set FuncGraphIndex::GetNodes(const std::string& key) { return std::set(); } -FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string& key) { +FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { if (GetFuncGraphs(key).empty()) { return nullptr; } @@ -358,7 +358,7 @@ FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string& key) { return fg; } -AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string& key) { +AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { if (GetNodes(key).empty()) { return nullptr; } @@ -367,14 +367,14 @@ AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string& key) { return node; } -void FuncGraphIndex::Acquire(const FuncGraphPtr& key) { +void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { std::string name = label_manage::Label(key->debug_info()); if (!name.empty()) { (void)index_func_graph_[name].insert(key); } } -void FuncGraphIndex::Acquire(const AnfNodePtr& key) { +void FuncGraphIndex::Acquire(const AnfNodePtr &key) { std::string name = label_manage::Label(key->debug_info()); if (!name.empty()) { (void)index_node_[name].insert(key); @@ -382,8 +382,8 @@ void FuncGraphIndex::Acquire(const AnfNodePtr& key) { } // Isomorphism -static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { if 
(equiv_node == nullptr) { MS_LOG(ERROR) << "Invalid equiv_node"; return false; @@ -419,13 +419,13 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu return false; } -static bool SameNode(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { MS_EXCEPTION_IF_NULL(node1); MS_EXCEPTION_IF_NULL(node2); if (node1->isa() && node2->isa()) { - auto& inputs1 = node1->cast()->inputs(); - auto& inputs2 = node2->cast()->inputs(); + auto &inputs1 = node1->cast()->inputs(); + auto &inputs2 = node2->cast()->inputs(); for (std::size_t i = 0; i < inputs1.size(); ++i) { if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) { return false; @@ -436,8 +436,8 @@ static bool SameNode(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraph return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); } -static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { std::unordered_set done; std::stack> todo; @@ -479,8 +479,8 @@ static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEqu return true; } -bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { auto fg1_fg2 = std::make_pair(fg1, fg2); if (equiv_func_graph == nullptr) { MS_LOG(ERROR) << "equiv_func_graph not init"; @@ -511,7 +511,7 @@ bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv* equiv return false; } -tensor::TensorPtr ScalarToTensor(const ScalarPtr& scalar) { +tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { if (scalar == nullptr) { MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; } diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index 57bc0e42fcf..d01335af829 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -38,42 +38,42 @@ namespace mindspore { enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; -using IncludeFunc = std::function; +using IncludeFunc = std::function; using SuccFunc = std::function(AnfNodePtr)>; -using SearchFunc = std::function(const AnfNodePtr&, const IncludeFunc&)>; +using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; -std::vector SuccDeeper(const AnfNodePtr& node); -std::vector SuccDeeperSimple(const AnfNodePtr& node); -std::vector SuccIncoming(const AnfNodePtr& node); -std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& node); +std::vector SuccDeeper(const AnfNodePtr &node); +std::vector SuccDeeperSimple(const AnfNodePtr &node); +std::vector SuccIncoming(const AnfNodePtr &node); +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); -IncludeType AlwaysInclude(const AnfNodePtr& node); -IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node); +IncludeType AlwaysInclude(const AnfNodePtr &node); +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); -std::vector DeepScopedGraphSearch(const 
AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); -std::vector DeepUsedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); -std::vector DeepLinkedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); -std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ = SuccIncoming, - const IncludeFunc& include = AlwaysInclude); +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, + const IncludeFunc &include = AlwaysInclude); class FuncGraphIndex { public: - explicit FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search = DeepScopedGraphSearch, - const IncludeFunc& include = AlwaysInclude); - FuncGraphIndex(const FuncGraphIndex&) = delete; - FuncGraphIndex& operator=(const FuncGraphIndex&) = delete; + explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, + const IncludeFunc &include = AlwaysInclude); + FuncGraphIndex(const FuncGraphIndex &) = delete; + FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; virtual ~FuncGraphIndex() {} - std::set GetFuncGraphs(const std::string& key); - std::set GetNodes(const std::string& key); - FuncGraphPtr GetFirstFuncGraph(const std::string& key); - AnfNodePtr GetFirstNode(const std::string& key); + std::set GetFuncGraphs(const std::string &key); + std::set GetNodes(const std::string &key); + FuncGraphPtr GetFirstFuncGraph(const std::string &key); + AnfNodePtr GetFirstNode(const std::string &key); private: - void Acquire(const FuncGraphPtr& key); - void Acquire(const AnfNodePtr& key); + void Acquire(const FuncGraphPtr &key); + void Acquire(const AnfNodePtr &key); std::map> index_func_graph_; std::map> index_node_; @@ -83,7 +83,7 @@ class FuncGraphIndex { struct PairHasher { template - std::size_t operator()(const std::pair& p) const { + std::size_t operator()(const std::pair &p) const { auto h1 = std::hash{}(p.first); auto h2 = std::hash{}(p.second); return h1 ^ h2; @@ -95,9 +95,9 @@ enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 }; using FuncGraphPairMapEquiv = std::unordered_map, EquivState, PairHasher>; using NodeMapEquiv = std::unordered_map; -bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv* equiv_func_graph, NodeMapEquiv* equiv_node); +bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); -tensor::TensorPtr ScalarToTensor(const ScalarPtr& scalar); +tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ diff --git a/mindspore/ccsrc/utils/hashing.h b/mindspore/ccsrc/utils/hashing.h index 730657ce7a9..cc8cc5b9914 100644 --- a/mindspore/ccsrc/utils/hashing.h +++ b/mindspore/ccsrc/utils/hashing.h @@ -25,7 +25,7 @@ inline std::size_t hash_combine(std::size_t hash_sum, std::size_t hash_val) { return ((hash_sum << 6) + (hash_sum >> 2) + 0x9e3779b9 + hash_val) ^ hash_sum; } -inline std::size_t hash_combine(const std::initializer_list& hash_vals) { +inline std::size_t hash_combine(const std::initializer_list &hash_vals) { std::size_t hash_sum = 0; for (auto hash_val : hash_vals) { hash_sum = hash_combine(hash_sum, hash_val); 
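// The mixing step from hashing.h can be exercised on its own; this is a standalone copy
// for illustration (same expression as in the header), folding two hashes into one
// accumulator. The hashed values are arbitrary examples.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>

inline std::size_t hash_combine(std::size_t hash_sum, std::size_t hash_val) {
  return ((hash_sum << 6) + (hash_sum >> 2) + 0x9e3779b9 + hash_val) ^ hash_sum;
}

int main() {
  std::size_t h1 = std::hash<std::string>{}("graph_0");
  std::size_t h2 = std::hash<int>{}(42);
  // Fold the second hash into the first, starting from an empty accumulator.
  std::cout << hash_combine(hash_combine(0, h1), h2) << "\n";
  return 0;
}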
diff --git a/mindspore/ccsrc/utils/misc.cc b/mindspore/ccsrc/utils/misc.cc index 47e675a3413..a9eb8071ef8 100644 --- a/mindspore/ccsrc/utils/misc.cc +++ b/mindspore/ccsrc/utils/misc.cc @@ -23,9 +23,9 @@ const int RET_FAILED = 1; const int RET_CONTINUE = 2; const int RET_BREAK = 3; -std::string demangle(const char* name) { +std::string demangle(const char *name) { int status = -1; - std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; + std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; return (status == 0) ? res.get() : name; } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/misc.h b/mindspore/ccsrc/utils/misc.h index 66e8937f9cc..e2cdebe98ad 100644 --- a/mindspore/ccsrc/utils/misc.h +++ b/mindspore/ccsrc/utils/misc.h @@ -33,7 +33,7 @@ extern const int RET_CONTINUE; extern const int RET_BREAK; // demangle the name to make it human reablable. -extern std::string demangle(const char* name); +extern std::string demangle(const char *name); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_MISC_H_ diff --git a/mindspore/ccsrc/utils/ordered_set.h b/mindspore/ccsrc/utils/ordered_set.h index b22053f196c..f393ce74f2e 100644 --- a/mindspore/ccsrc/utils/ordered_set.h +++ b/mindspore/ccsrc/utils/ordered_set.h @@ -53,28 +53,28 @@ class OrderedSet { // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use // traversal to build elements. - OrderedSet(const OrderedSet& os) { - for (auto& item : os.ordered_data_) { + OrderedSet(const OrderedSet &os) { + for (auto &item : os.ordered_data_) { add(item); } } - explicit OrderedSet(const sequential_type& other) { - for (auto& item : other) { + explicit OrderedSet(const sequential_type &other) { + for (auto &item : other) { add(item); } } // Explicitly construct an OrderedSet use vector - explicit OrderedSet(const vector_type& other) { - for (auto& item : other) { + explicit OrderedSet(const vector_type &other) { + for (auto &item : other) { add(item); } } - OrderedSet& operator=(const OrderedSet& os) { + OrderedSet &operator=(const OrderedSet &os) { if (this != &os) { - for (auto& item : os.ordered_data_) { + for (auto &item : os.ordered_data_) { add(item); } } @@ -82,14 +82,14 @@ class OrderedSet { } // Add an element to the OrderedSet, without judging return value - void add(const element_type& e) { (void)insert(e); } + void add(const element_type &e) { (void)insert(e); } // insert an element to the OrderedSet - std::pair insert(const element_type& e) { + std::pair insert(const element_type &e) { iterator empty_itr; std::pair map_pair = std::make_pair(e, empty_itr); auto result = mapped_data_.insert(map_pair); - auto& seq_idx = result.first->second; + auto &seq_idx = result.first->second; // if insert success; if (result.second) { auto it = ordered_data_.insert(ordered_data_.end(), e); @@ -99,7 +99,7 @@ class OrderedSet { } // Remove an element, if removed return true, otherwise return false - bool erase(const element_type& e) { + bool erase(const element_type &e) { auto pos = mapped_data_.find(e); if (pos == mapped_data_.end()) { return false; @@ -119,7 +119,7 @@ class OrderedSet { std::string toString() { std::ostringstream res; res << "orderset content:\n"; - for (auto& item : ordered_data_) { + for (auto &item : ordered_data_) { res << std::to_string(reinterpret_cast(item.get())) << " "; } return res.str(); @@ 
-132,7 +132,7 @@ class OrderedSet { } // Compare two orderedset, if the order is not equal shall return false - bool operator==(const OrderedSet& other) const { return ordered_data_ == other.ordered_data_; } + bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } // Remove and return the first element in the OrderedSet T pop() { @@ -153,8 +153,8 @@ class OrderedSet { } // Return true if there are no common elements - bool is_disjoint(const OrderedSet& other) { - for (auto& item : other.ordered_data_) { + bool is_disjoint(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { if (mapped_data_.find(item) != mapped_data_.end()) { return false; } @@ -163,8 +163,8 @@ class OrderedSet { } // Test whether this is subset of other - bool is_subset(const OrderedSet& other) { - for (auto& item : ordered_data_) { + bool is_subset(const OrderedSet &other) { + for (auto &item : ordered_data_) { if (other.mapped_data_.find(item) == other.mapped_data_.end()) { return false; } @@ -173,51 +173,51 @@ class OrderedSet { } // Add elements in other to this orderedset - void update(const OrderedSet& other) { - for (auto& item : other.ordered_data_) { + void update(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { add(item); } } - void update(const std::shared_ptr& other) { update(*other); } + void update(const std::shared_ptr &other) { update(*other); } - void update(const sequential_type& other) { - for (auto& item : other) { + void update(const sequential_type &other) { + for (auto &item : other) { add(item); } } - void update(const vector_type& other) { - for (auto& item : other) { + void update(const vector_type &other) { + for (auto &item : other) { add(item); } } - ordered_set_type get_union(const OrderedSet& other) { + ordered_set_type get_union(const OrderedSet &other) { ordered_set_type res(ordered_data_); res.update(other); return res; } // Get the union with other set, this operator may cost time because of copy - ordered_set_type operator|(const OrderedSet& other) { return get_union(other); } + ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } // Return the intersection of two sets - ordered_set_type intersection(const OrderedSet& other) { + ordered_set_type intersection(const OrderedSet &other) { ordered_set_type res(ordered_data_); - for (auto& item : ordered_data_) { + for (auto &item : ordered_data_) { if (other.mapped_data_.find(item) == other.mapped_data_.end()) { (void)res.erase(item); } } return res; } - ordered_set_type operator&(const OrderedSet& other) { return intersection(other); } + ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } // Return the symmetric difference of two sets - ordered_set_type symmetric_difference(const OrderedSet& other) { + ordered_set_type symmetric_difference(const OrderedSet &other) { ordered_set_type res(ordered_data_); - for (auto& item : other.ordered_data_) { + for (auto &item : other.ordered_data_) { if (mapped_data_.find(item) != mapped_data_.end()) { (void)res.erase(item); } else { @@ -227,40 +227,40 @@ class OrderedSet { return res; } - ordered_set_type operator^(const OrderedSet& other) { return symmetric_difference(other); } + ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } // Remove elements which is also in others. 
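// Standalone sketch of the OrderedSet design described in ordered_set.h (not MindSpore
// code): a std::list keeps insertion order, while a hash map from element to list
// iterator gives fast membership tests and erasure. MiniOrderedSet is an invented name.
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>

class MiniOrderedSet {
 public:
  void add(const std::string &e) {
    if (index_.count(e) != 0) return;        // ignore duplicates
    index_[e] = data_.insert(data_.end(), e);  // remember position in the ordered list
  }
  bool erase(const std::string &e) {
    auto pos = index_.find(e);
    if (pos == index_.end()) return false;
    data_.erase(pos->second);
    index_.erase(pos);
    return true;
  }
  bool contains(const std::string &e) const { return index_.find(e) != index_.end(); }
  const std::list<std::string> &elements() const { return data_; }

 private:
  std::list<std::string> data_;
  std::unordered_map<std::string, std::list<std::string>::iterator> index_;
};

int main() {
  MiniOrderedSet s;
  s.add("conv1");
  s.add("bn1");
  s.add("conv1");  // duplicate, ignored
  s.erase("bn1");
  for (const auto &e : s.elements()) std::cout << e << "\n";  // prints: conv1
  return 0;
}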
- void difference_update(const OrderedSet& other) { + void difference_update(const OrderedSet &other) { // use vector traversal, to keep ordrer - for (auto& item : other.ordered_data_) { + for (auto &item : other.ordered_data_) { (void)erase(item); } } - void difference_update(const sequential_type& other) { - for (auto& item : other) { + void difference_update(const sequential_type &other) { + for (auto &item : other) { (void)erase(item); } } - void difference_update(const vector_type& other) { - for (auto& item : other) { + void difference_update(const vector_type &other) { + for (auto &item : other) { (void)erase(item); } } // Return the set with elements that are not in the others - ordered_set_type difference(const OrderedSet& other) { + ordered_set_type difference(const OrderedSet &other) { ordered_set_type res(ordered_data_); res.difference_update(other); return res; } - ordered_set_type operator-(const OrderedSet& other) { return difference(other); } + ordered_set_type operator-(const OrderedSet &other) { return difference(other); } - bool contains(const element_type& e) const { return (mapped_data_.find(e) != mapped_data_.end()); } + bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } // Return the count of an element in set - std::size_t count(const element_type& e) const { return mapped_data_.count(e); } + std::size_t count(const element_type &e) const { return mapped_data_.count(e); } iterator begin() { return ordered_data_.begin(); } iterator end() { return ordered_data_.end(); } diff --git a/mindspore/ccsrc/utils/profile.cc b/mindspore/ccsrc/utils/profile.cc index ba490549f8e..997cc1b56da 100644 --- a/mindspore/ccsrc/utils/profile.cc +++ b/mindspore/ccsrc/utils/profile.cc @@ -33,11 +33,11 @@ namespace { constexpr size_t TIME_INFO_PREFIX_NUM_LEN = 4; const char KEY_PROF_TOTAL[] = "__total__"; -void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent = 0, - std::map* sums = nullptr, const std::string& prefix = ""); +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = ""); -void PrintTimeInfoMap(std::ostringstream& oss, const TimeInfoMap& dict, int indent = 0, - std::map* sums = nullptr, const std::string& prefix = "") { +void PrintTimeInfoMap(std::ostringstream &oss, const TimeInfoMap &dict, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = "") { for (auto iter = dict.begin(); iter != dict.end(); ++iter) { if (iter->second == nullptr) { continue; @@ -62,8 +62,8 @@ void PrintTimeInfoMap(std::ostringstream& oss, const TimeInfoMap& dict, int inde } } -void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent, std::map* sums, - const std::string& prefix) { +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent, std::map *sums, + const std::string &prefix) { bool need_free = false; if (sums == nullptr) { sums = new (std::nothrow) std::map(); @@ -95,7 +95,7 @@ void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent } oss << "Sums\n"; if (total >= 0.0 + DBL_EPSILON) { - for (auto& iter : *sums) { + for (auto &iter : *sums) { std::string name = iter.first; name.erase(0, TIME_INFO_PREFIX_NUM_LEN); std::size_t pos = 0; @@ -159,7 +159,7 @@ void Profile::Print(void) { // Start a step in the current context with the given name. // Nomes must be unique otherwise the previous record will be overwritten. 
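// Standalone sketch of the timing pattern used by Profile/ProfTransaction below: take a
// wall-clock reading before and after a callable and accumulate the difference under a
// name. GetTimeSeconds and TimeIt are local stand-ins, not the GetTime()/operator- from
// profile.h; "renormalize." is one of the stage prefixes listed in MsProfile::Print.
#include <chrono>
#include <iostream>
#include <map>
#include <string>

static double GetTimeSeconds() {
  auto now = std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration<double>(now).count();
}

template <class Function>
void TimeIt(std::map<std::string, double> *stat, const std::string &name, const Function &func) {
  double start = GetTimeSeconds();
  func();
  (*stat)[name] += GetTimeSeconds() - start;
}

int main() {
  std::map<std::string, double> stat;
  TimeIt(&stat, "renormalize.", [] { /* work being measured */ });
  std::cout << "renormalize. took " << stat["renormalize."] << " s\n";
  return 0;
}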
-ProfContext* Profile::Step(const std::string& name) { +ProfContext *Profile::Step(const std::string &name) { ctx_ptr_ = new (std::nothrow) ProfContext(name, this); if (ctx_ptr_ == nullptr) { MS_LOG(ERROR) << "memory allocation failed"; @@ -170,7 +170,7 @@ ProfContext* Profile::Step(const std::string& name) { // Creates subcontext for a repeated action. // Count should be monotonically increasing. -ProfContext* Profile::Lap(int count) { +ProfContext *Profile::Lap(int count) { std::ostringstream oss; oss << "Cycle " << count; ctx_ptr_ = new (std::nothrow) ProfContext(oss.str(), this); @@ -188,7 +188,7 @@ void Profile::Pop(void) noexcept { ctx_ptr_ = ctx_ptr_->parent_; } -ProfContext::ProfContext(const std::string& name, ProfileBase* const prof) : name_(name), prof_(prof) { +ProfContext::ProfContext(const std::string &name, ProfileBase *const prof) : name_(name), prof_(prof) { // Initialize a subcontext. time_info_ = nullptr; if (prof == nullptr || IsTopContext()) { @@ -227,7 +227,7 @@ void ProfContext::SetTime(double time) noexcept { time_info_->time_ = time; } -void ProfContext::Insert(const std::string& name, const TimeInfo* time) noexcept { +void ProfContext::Insert(const std::string &name, const TimeInfo *time) noexcept { if (time_info_ == nullptr) { time_info_ = new (std::nothrow) TimeInfo(); if (time_info_ == nullptr) { @@ -266,7 +266,7 @@ void ProfContext::Insert(const std::string& name, const TimeInfo* time) noexcept bool ProfContext::IsTopContext() const noexcept { return (prof_ != nullptr) && (this == &prof_->context_); } -ProfTransaction::ProfTransaction(const ProfileBase* prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } +ProfTransaction::ProfTransaction(const ProfileBase *prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } ProfTransaction::~ProfTransaction() { if (ctx_ != nullptr && !ctx_->IsTopContext()) { @@ -275,7 +275,7 @@ ProfTransaction::~ProfTransaction() { ctx_ = nullptr; } -void DumpTime::Record(const std::string& step_name, const double time, const bool is_start) { +void DumpTime::Record(const std::string &step_name, const double time, const bool is_start) { file_ss_ << " {" << std::endl; file_ss_ << " \"name\": " << "\"" << step_name << "\"," << std::endl; @@ -298,7 +298,7 @@ void DumpTime::Record(const std::string& step_name, const double time, const boo void DumpTime::Save() { try { file_out_.open(file_path_, std::ios::trunc | std::ios::out); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "Cannot open file in " << (file_path_); } file_out_ << "{\n"; @@ -317,10 +317,10 @@ struct TimeInfoGroup { std::list::const_iterator> items; }; -static void PrintTimeStat(std::ostringstream& oss, const TimeInfoGroup& group, const std::string& prefix) { +static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, const std::string &prefix) { oss << "------[" << prefix << "] " << std::setw(10) << std::fixed << std::setprecision(6) << group.total_time << std::setw(6) << group.total_count << "\n"; - for (const auto& iter : group.items) { + for (const auto &iter : group.items) { oss << std::setw(5) << std::fixed << std::setprecision(2) << iter->second.time_ / group.total_time * 100 << "% : " << std::setw(12) << std::fixed << std::setprecision(6) << iter->second.time_ << "s : " << std::setw(6) << iter->second.count_ << ": " << iter->first << "\n"; @@ -332,7 +332,7 @@ void MsProfile::Print() { std::vector items = {"substitution.", "renormalize.", "replace.", "match.", "func_graph_cloner_run.", 
"meta_graph.", "manager."}; std::vector groups(items.size() + 1); - const auto& stat = GetSingleton().time_stat_; + const auto &stat = GetSingleton().time_stat_; // group all time infos for (auto iter = stat.cbegin(); iter != stat.cend(); ++iter) { auto matched_idx = items.size(); diff --git a/mindspore/ccsrc/utils/profile.h b/mindspore/ccsrc/utils/profile.h index 6892b0b4f67..bd3723d5bba 100644 --- a/mindspore/ccsrc/utils/profile.h +++ b/mindspore/ccsrc/utils/profile.h @@ -27,7 +27,7 @@ namespace mindspore { struct TimeInfo; -using TimeInfoMap = std::map; +using TimeInfoMap = std::map; extern double GetTime(); @@ -35,11 +35,11 @@ class ProfileBase; struct TimeInfo { explicit TimeInfo(double time = -1.0) : time_(time), dict_(nullptr), actionNum_(0) {} - TimeInfo(const TimeInfo&) = delete; + TimeInfo(const TimeInfo &) = delete; ~TimeInfo(); double time_; - TimeInfoMap* dict_; + TimeInfoMap *dict_; size_t actionNum_; }; @@ -50,21 +50,21 @@ class ProfContext { friend class ProfTransaction; public: - ProfContext(const std::string& name, ProfileBase* prof); + ProfContext(const std::string &name, ProfileBase *prof); ~ProfContext(); - ProfContext(const ProfContext&) = delete; - ProfContext& operator=(const ProfContext&) = delete; + ProfContext(const ProfContext &) = delete; + ProfContext &operator=(const ProfContext &) = delete; void SetTime(double time) noexcept; - void Insert(const std::string& name, const TimeInfo* time) noexcept; + void Insert(const std::string &name, const TimeInfo *time) noexcept; bool IsTopContext() const noexcept; private: std::string name_; - ProfileBase* prof_; - ProfContext* parent_; - TimeInfo* time_info_; + ProfileBase *prof_; + ProfContext *parent_; + TimeInfo *time_info_; }; class ProfileBase { @@ -76,38 +76,38 @@ class ProfileBase { virtual ~ProfileBase(); virtual void Print(void) {} - virtual ProfContext* Step(const std::string&) { return nullptr; } - virtual ProfContext* Lap(int) { return nullptr; } + virtual ProfContext *Step(const std::string &) { return nullptr; } + virtual ProfContext *Lap(int) { return nullptr; } virtual void Pop(void) {} // top level profile context ProfContext context_; // profile context pointer, act as a stack pointer - ProfContext* ctx_ptr_ = nullptr; + ProfContext *ctx_ptr_ = nullptr; }; class Profile : public ProfileBase { public: Profile() = default; ~Profile() override = default; - Profile(const Profile&) = delete; - Profile& operator=(const Profile&) = delete; + Profile(const Profile &) = delete; + Profile &operator=(const Profile &) = delete; void Print(void) override; - ProfContext* Step(const std::string& name) override; - ProfContext* Lap(int count) override; + ProfContext *Step(const std::string &name) override; + ProfContext *Lap(int count) override; void Pop(void) noexcept override; }; class ProfTransaction { public: - explicit ProfTransaction(const ProfileBase* prof); - explicit ProfTransaction(ProfContext* const ctx) : ctx_(ctx) {} - ProfTransaction(const ProfTransaction&) = delete; + explicit ProfTransaction(const ProfileBase *prof); + explicit ProfTransaction(ProfContext *const ctx) : ctx_(ctx) {} + ProfTransaction(const ProfTransaction &) = delete; ~ProfTransaction(); template - void operator-(const Function& func) { + void operator-(const Function &func) { double start_time = GetTime(); func(); double end_time = GetTime(); @@ -117,17 +117,17 @@ class ProfTransaction { } private: - ProfContext* ctx_ = nullptr; + ProfContext *ctx_ = nullptr; }; class NoProfTransaction { public: - explicit NoProfTransaction(ProfileBase* 
prof) {} - explicit NoProfTransaction(ProfContext* ctx) {} + explicit NoProfTransaction(ProfileBase *prof) {} + explicit NoProfTransaction(ProfContext *ctx) {} ~NoProfTransaction() = default; template - void operator-(const Function& func) { + void operator-(const Function &func) { func(); } }; @@ -137,20 +137,20 @@ class DumpTime { ~DumpTime() { try { Save(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Cannot save file by profile::DumpTime::save"; } catch (...) { MS_LOG(ERROR) << "Uncaught exception"; } } - DumpTime(const DumpTime&) = delete; - DumpTime& operator=(const DumpTime&) = delete; - static DumpTime& GetInstance() { + DumpTime(const DumpTime &) = delete; + DumpTime &operator=(const DumpTime &) = delete; + static DumpTime &GetInstance() { static DumpTime instance; return instance; } - void set_file_path(const std::string& save_path) { file_path_ = save_path; } - void Record(const std::string& name, const double time, const bool is_start); + void set_file_path(const std::string &save_path) { file_path_ = save_path; } + void Record(const std::string &name, const double time, const bool is_start); void Save(); private: @@ -188,8 +188,8 @@ class MsProfile { static void Reset() { GetSingleton().Clear(); } - static ProfileBase* GetProfile() { - MsProfile& ms_prof = GetSingleton(); + static ProfileBase *GetProfile() { + MsProfile &ms_prof = GetSingleton(); if (ms_prof.profile_ == nullptr) { #ifdef ENABLE_PROFILE ms_prof.profile_ = new Profile(); @@ -199,14 +199,14 @@ class MsProfile { } return ms_prof.profile_; } - static void StatTime(const std::string& id, double time) { GetSingleton().time_stat_[id] += time; } + static void StatTime(const std::string &id, double time) { GetSingleton().time_stat_[id] += time; } static void Print(); private: MsProfile() = default; - static MsProfile& GetSingleton() { + static MsProfile &GetSingleton() { static MsProfile profile; return profile; } @@ -220,7 +220,7 @@ class MsProfile { } std::map time_stat_; // record time and count info from some activity - ProfileBase* profile_ = nullptr; // record hierarchical profile info + ProfileBase *profile_ = nullptr; // record hierarchical profile info }; } // namespace mindspore diff --git a/mindspore/ccsrc/utils/signal.h b/mindspore/ccsrc/utils/signal.h index af7b36a8b5d..9a43e23814d 100644 --- a/mindspore/ccsrc/utils/signal.h +++ b/mindspore/ccsrc/utils/signal.h @@ -24,14 +24,14 @@ namespace mindspore { template -std::function bind_member(Type* instance, Return (Type::*method)(Args...)) { - return [=](Args&&... args) -> Return { return (instance->*method)(std::forward(args)...); }; +std::function bind_member(Type *instance, Return (Type::*method)(Args...)) { + return [=](Args &&... args) -> Return { return (instance->*method)(std::forward(args)...); }; } template class Slot { public: - explicit Slot(const std::function& callback) : callback(callback) {} + explicit Slot(const std::function &callback) : callback(callback) {} ~Slot() {} @@ -42,15 +42,15 @@ template class Signal { public: template - void operator()(Args&&... args) { - for (auto& slot : slots_) { + void operator()(Args &&... 
args) { + for (auto &slot : slots_) { if (slot->callback != nullptr) { slot->callback(std::forward(args)...); } } } - void add_slot(const std::function& func) { + void add_slot(const std::function &func) { auto slot = std::make_shared>(func); slots_.push_back(slot); } diff --git a/mindspore/ccsrc/utils/symbolic.cc b/mindspore/ccsrc/utils/symbolic.cc index 8764678288d..8ad16e50c84 100644 --- a/mindspore/ccsrc/utils/symbolic.cc +++ b/mindspore/ccsrc/utils/symbolic.cc @@ -22,29 +22,29 @@ namespace mindspore { -std::ostream& operator<<(std::ostream& out, const std::shared_ptr& objPtr) { +std::ostream &operator<<(std::ostream &out, const std::shared_ptr &objPtr) { out << "("; MS_EXCEPTION_IF_NULL(objPtr); - for (auto& iter : objPtr->contents_) { + for (auto &iter : objPtr->contents_) { out << iter.first << ":" << iter.second << ";"; } out << ")"; return out; } -bool EnvInstance::operator==(const EnvInstance& other) const { +bool EnvInstance::operator==(const EnvInstance &other) const { if (Len() != other.Len()) { return false; } bool equal = std::all_of(contents_.begin(), contents_.end(), - [&other](const std::pair& item) -> bool { + [&other](const std::pair &item) -> bool { return other.contents_.find(item.first) != other.contents_.end(); }); return equal; } -bool EnvInstance::operator==(const Value& other) const { +bool EnvInstance::operator==(const Value &other) const { if (other.isa()) { - auto other_env_inst = static_cast(&other); + auto other_env_inst = static_cast(&other); return *this == *other_env_inst; } return false; diff --git a/mindspore/ccsrc/utils/symbolic.h b/mindspore/ccsrc/utils/symbolic.h index 3c712483ee8..a373c235731 100644 --- a/mindspore/ccsrc/utils/symbolic.h +++ b/mindspore/ccsrc/utils/symbolic.h @@ -32,18 +32,18 @@ namespace mindspore { class SymbolicKeyInstance : public Value { public: - SymbolicKeyInstance(const AnfNodePtr& node, const abstract::AbstractBasePtr& abstract) + SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) : node_(node), abstract_(abstract) {} ~SymbolicKeyInstance() override = default; MS_DECLARE_PARENT(SymbolicKeyInstance, Value); AnfNodePtr node() const { return node_; } abstract::AbstractBasePtr abstract() const { return abstract_; } - bool operator==(const SymbolicKeyInstance& other) const { + bool operator==(const SymbolicKeyInstance &other) const { return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); } std::size_t hash() const override { return std::hash{}(node_); } - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr& inst) { + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { if (inst == nullptr) { os << "[Key][" << "Invalid symbolic key instance" @@ -56,9 +56,9 @@ class SymbolicKeyInstance : public Value { std::string ToString() const override { return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); } - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -106,19 +106,19 @@ using EnvInstanceContentsMap = // with inferred properties. 
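// Standalone sketch of the Signal/Slot mechanism from signal.h (MiniSignal is an invented
// name): slots are stored as std::function callbacks and invoked in registration order
// whenever the signal is raised.
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

template <class FuncType>
class MiniSignal {
 public:
  void add_slot(const std::function<FuncType> &func) { slots_.push_back(func); }

  template <class... Args>
  void operator()(Args &&... args) {
    for (auto &slot : slots_) {
      if (slot) {
        slot(std::forward<Args>(args)...);
      }
    }
  }

 private:
  std::vector<std::function<FuncType>> slots_;
};

int main() {
  MiniSignal<void(int)> sig;
  sig.add_slot([](int graph_id) { std::cout << "graph " << graph_id << " finished\n"; });
  sig(7);  // raises the signal, invoking every registered slot
  return 0;
}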
class EnvInstance : public Value { public: - friend std::ostream& operator<<(std::ostream& out, const std::shared_ptr& env); + friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); - explicit EnvInstance(const EnvInstanceContentsMap& contents = {}) : contents_(contents) {} + explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} ~EnvInstance() override = default; MS_DECLARE_PARENT(EnvInstance, Value); abstract::AbstractBasePtr ToAbstract() override { return std::make_shared(shared_from_base(), std::make_shared()); } - bool operator==(const EnvInstance& other) const; - bool operator==(const Value& other) const override; - EnvInstance(const EnvInstance& v) : Value(v), contents_(v.contents_) {} - EnvInstance(EnvInstance&& v) = default; - EnvInstance& operator=(EnvInstance&& src) noexcept { + bool operator==(const EnvInstance &other) const; + bool operator==(const Value &other) const override; + EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} + EnvInstance(EnvInstance &&v) = default; + EnvInstance &operator=(EnvInstance &&src) noexcept { if (&src != this) { contents_ = src.contents_; } @@ -126,7 +126,7 @@ class EnvInstance : public Value { }; // Get the sensitivity list for the given key - const Any& Get(const SymbolicKeyInstancePtr& key, const Any& def) const { + const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { auto iterator = contents_.find(key); if (iterator != contents_.end()) { return iterator->second; @@ -135,14 +135,14 @@ class EnvInstance : public Value { } // Set a value for the given key. - EnvInstance Set(const SymbolicKeyInstancePtr& key, const Any& value) const { + EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { EnvInstance rval(contents_); rval.contents_[key] = value; return rval; } // Add two EnvInstances. 
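// Standalone sketch of the value-semantic update style used by EnvInstance::Set above
// (MiniEnv is an invented name with int values for brevity): Set never mutates the
// receiver, it returns a copy that holds the extra entry.
#include <iostream>
#include <map>
#include <string>

class MiniEnv {
 public:
  MiniEnv() = default;
  explicit MiniEnv(const std::map<std::string, int> &contents) : contents_(contents) {}

  int Get(const std::string &key, int def) const {
    auto it = contents_.find(key);
    return it != contents_.end() ? it->second : def;
  }
  MiniEnv Set(const std::string &key, int value) const {
    MiniEnv rval(contents_);
    rval.contents_[key] = value;  // only the copy is modified
    return rval;
  }

 private:
  std::map<std::string, int> contents_;
};

int main() {
  MiniEnv base;
  MiniEnv updated = base.Set("grad_x", 1);
  std::cout << base.Get("grad_x", 0) << " " << updated.Get("grad_x", 0) << "\n";  // prints: 0 1
  return 0;
}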
- EnvInstance Add(const EnvInstance& other) const { + EnvInstance Add(const EnvInstance &other) const { EnvInstance rval(contents_); for (auto iter_other : other.contents_) { auto item_self = contents_.find(iter_other.first); diff --git a/mindspore/ccsrc/utils/system/base.h b/mindspore/ccsrc/utils/system/base.h index dace2e71786..4cfb5b312db 100644 --- a/mindspore/ccsrc/utils/system/base.h +++ b/mindspore/ccsrc/utils/system/base.h @@ -108,7 +108,7 @@ constexpr bool kLittleEndian = true; // implement common define function // Get the 32 bits align value -inline uint32 DecodeFixed32(const char* ptr) { +inline uint32 DecodeFixed32(const char *ptr) { uint32 result; if (EOK != memcpy_s(&result, sizeof(result), ptr, sizeof(result))) { MS_LOG(EXCEPTION) << "Call DecodeFixed32 memcpy value failure."; @@ -116,14 +116,14 @@ inline uint32 DecodeFixed32(const char* ptr) { return result; } // Used to fetch a naturally-aligned 32-bit word in little endian byte-order -inline uint32 LE_LOAD32(const uint8_t* p) { return DecodeFixed32(reinterpret_cast(p)); } +inline uint32 LE_LOAD32(const uint8_t *p) { return DecodeFixed32(reinterpret_cast(p)); } // Encode the data to buffer -inline void EncodeFixed32(char* buf, uint32 value) { +inline void EncodeFixed32(char *buf, uint32 value) { if (EOK != memcpy_s(buf, sizeof(value), &value, sizeof(value))) { MS_LOG(EXCEPTION) << "Call EncodeFixed32 memcpy value failure."; } } -inline void EncodeFixed64(char* buf, const unsigned int array_len, int64 value) { +inline void EncodeFixed64(char *buf, const unsigned int array_len, int64 value) { if (sizeof(value) > array_len) { MS_LOG(EXCEPTION) << "Buffer overflow, real size is " << array_len << ", but required " << sizeof(value) << "."; } diff --git a/mindspore/ccsrc/utils/system/crc32c.h b/mindspore/ccsrc/utils/system/crc32c.h index 4411423babe..d23b9ad4639 100644 --- a/mindspore/ccsrc/utils/system/crc32c.h +++ b/mindspore/ccsrc/utils/system/crc32c.h @@ -40,10 +40,10 @@ class Crc32c { ~Crc32c() = default; // Calculate the crc32c value, use the 8 table method - static uint32 MakeCrc32c(uint32 init_crc, const char* data, size_t size); + static uint32 MakeCrc32c(uint32 init_crc, const char *data, size_t size); // retrun the crc32c value(need mask) - static uint32 GetMaskCrc32cValue(const char* data, size_t n) { + static uint32 GetMaskCrc32cValue(const char *data, size_t n) { auto crc = MakeCrc32c(0, data, n); // Rotate right by kRightShift bits and add kMaskDelta(a constant). 
return ((crc >> kRightShift) | (crc << kLeftShift)) + kMaskDelta; diff --git a/mindspore/ccsrc/utils/system/file_system.cc b/mindspore/ccsrc/utils/system/file_system.cc index aee89d4b7bb..ce27108a39b 100644 --- a/mindspore/ccsrc/utils/system/file_system.cc +++ b/mindspore/ccsrc/utils/system/file_system.cc @@ -25,7 +25,7 @@ namespace system { #if defined(SYSTEM_ENV_POSIX) // Implement the Posix file systen -WriteFilePtr PosixFileSystem::CreateWriteFile(const string& file_name) { +WriteFilePtr PosixFileSystem::CreateWriteFile(const string &file_name) { if (file_name.empty()) { MS_LOG(ERROR) << "Create write file failed because the file name is null."; return nullptr; @@ -43,7 +43,7 @@ WriteFilePtr PosixFileSystem::CreateWriteFile(const string& file_name) { return fp; } -bool PosixFileSystem::FileExist(const string& file_name) { +bool PosixFileSystem::FileExist(const string &file_name) { if (file_name.empty()) { MS_LOG(WARNING) << "The file name is null."; return false; @@ -56,7 +56,7 @@ bool PosixFileSystem::FileExist(const string& file_name) { return true; } -bool PosixFileSystem::DeleteFile(const string& file_name) { +bool PosixFileSystem::DeleteFile(const string &file_name) { if (file_name.empty()) { MS_LOG(WARNING) << "The file name is null."; return false; @@ -70,7 +70,7 @@ bool PosixFileSystem::DeleteFile(const string& file_name) { } static const int DEFAULT_MKDIR_MODE = 0700; -bool PosixFileSystem::CreateDir(const string& dir_name) { +bool PosixFileSystem::CreateDir(const string &dir_name) { if (dir_name.empty()) { MS_LOG(WARNING) << "The directory name is null."; return false; @@ -83,7 +83,7 @@ bool PosixFileSystem::CreateDir(const string& dir_name) { return true; } -bool PosixFileSystem::DeleteDir(const string& dir_name) { +bool PosixFileSystem::DeleteDir(const string &dir_name) { if (dir_name.empty()) { MS_LOG(WARNING) << "The directory name is null."; return false; diff --git a/mindspore/ccsrc/utils/system/file_system.h b/mindspore/ccsrc/utils/system/file_system.h index ef0cf885be1..ed9db874c8f 100644 --- a/mindspore/ccsrc/utils/system/file_system.h +++ b/mindspore/ccsrc/utils/system/file_system.h @@ -45,25 +45,25 @@ class FileSystem { virtual ~FileSystem() = default; // Create a new read/write file - virtual WriteFilePtr CreateWriteFile(const string& file_name) = 0; + virtual WriteFilePtr CreateWriteFile(const string &file_name) = 0; // Check the file is exist? 
- virtual bool FileExist(const string& file_name) = 0; + virtual bool FileExist(const string &file_name) = 0; // Delete the file - virtual bool DeleteFile(const string& file_name) = 0; + virtual bool DeleteFile(const string &file_name) = 0; // Create a directory - virtual bool CreateDir(const string& dir_name) = 0; + virtual bool CreateDir(const string &dir_name) = 0; // Delete the specified directory - virtual bool DeleteDir(const string& dir_name) = 0; + virtual bool DeleteDir(const string &dir_name) = 0; }; // A file that can be read and write class WriteFile { public: - explicit WriteFile(const string& file_name) : file_name_(file_name) {} + explicit WriteFile(const string &file_name) : file_name_(file_name) {} virtual ~WriteFile() = default; @@ -71,7 +71,7 @@ class WriteFile { virtual bool Open() = 0; // append the content to file - virtual bool Write(const std::string& data) { + virtual bool Write(const std::string &data) { MS_LOG(WARNING) << "Attention: Maybe not call the function."; return true; } @@ -101,27 +101,27 @@ class PosixFileSystem : public FileSystem { ~PosixFileSystem() override = default; // create a new write file - WriteFilePtr CreateWriteFile(const string& file_name) override; + WriteFilePtr CreateWriteFile(const string &file_name) override; // check the file is exist? - bool FileExist(const string& file_name) override; + bool FileExist(const string &file_name) override; // delete the file - bool DeleteFile(const string& file_name) override; + bool DeleteFile(const string &file_name) override; // Create a Directory - bool CreateDir(const string& dir_name) override; + bool CreateDir(const string &dir_name) override; // Delete the specified directory. - bool DeleteDir(const string& dir_name) override; + bool DeleteDir(const string &dir_name) override; }; // A file that can be read and write for posix class PosixWriteFile : public WriteFile { public: - explicit PosixWriteFile(const string& file_name) : WriteFile(file_name), file_(nullptr) {} - PosixWriteFile(const PosixWriteFile&); - PosixWriteFile& operator=(const PosixWriteFile&); + explicit PosixWriteFile(const string &file_name) : WriteFile(file_name), file_(nullptr) {} + PosixWriteFile(const PosixWriteFile &); + PosixWriteFile &operator=(const PosixWriteFile &); ~PosixWriteFile() override { try { @@ -129,7 +129,7 @@ class PosixWriteFile : public WriteFile { (void)fclose(file_); file_ = nullptr; } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file."; } catch (...) 
{ MS_LOG(ERROR) << "Non standard exception when closing file."; @@ -159,7 +159,7 @@ class PosixWriteFile : public WriteFile { return true; } - bool Write(const std::string& data) override { + bool Write(const std::string &data) override { MS_LOG(DEBUG) << "Write data(" << data.size() << ") to file(" << this->file_name_ << ")."; size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { @@ -194,7 +194,7 @@ class PosixWriteFile : public WriteFile { bool Sync() override { return Flush(); } private: - FILE* file_; + FILE *file_; }; #endif diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index eac1b862739..f05eda69bfb 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -213,7 +213,7 @@ const std::set kOptOperatorSet = { const std::set kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; -static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { +static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { if (access(file_name.c_str(), F_OK) != 0) { MS_LOG(DEBUG) << "File `" << file_name << "` does not exist."; return; diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index d7d5a4c0964..ae052770ff9 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -47,7 +47,7 @@ void ClearConvertCache() { g_ConvertCache.clear(); } // lst: list of nodes (the segment) // users: dict mapping each node to its users (globally) // seen: set of nodes that are part of the segment -AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, const std::vector& seen) { +AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector &seen) { AnfNodePtrList output; if (users.size() == 0) { return output; @@ -57,7 +57,7 @@ AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, c std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr { auto usersn = users.find(n); bool is_referred_out_of_segment = std::any_of( - std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair& u) -> bool { + std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair &u) -> bool { return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen); }); if (n->isa() && is_referred_out_of_segment) { @@ -78,7 +78,7 @@ AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, c return output; } -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList& lst) { +std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { auto fg = std::make_shared(); AnfNodePtrList inputs; AnfNodePtrToAnfNodePtrMap eqv; @@ -86,7 +86,7 @@ std::tuple TransformSegmentToAnfGr MS_LOG(EXCEPTION) << "Input anf node list is empty"; } - auto ref = [&eqv, &inputs, &fg](const AnfNodePtr& a) -> AnfNodePtr { + auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { if (a->isa() && !IsValueNode(a)) { eqv[a] = a; } else if (eqv.find(a) == eqv.end()) { @@ -102,7 +102,7 @@ std::tuple TransformSegmentToAnfGr if (!n->isa()) { MS_LOG(EXCEPTION) << "Inst is not CNode"; } - auto& inps = n->cast()->inputs(); + auto &inps = n->cast()->inputs(); if (inps.empty()) { MS_LOG(EXCEPTION) << "Input is empty"; @@ -120,13 +120,13 @@ std::tuple TransformSegmentToAnfGr std::vector eqv_keys; (void)std::transform(std::begin(eqv), 
std::end(eqv), std::back_inserter(eqv_keys), - [](const std::pair& elem) -> AnfNodePtr { return elem.first; }); + [](const std::pair &elem) -> AnfNodePtr { return elem.first; }); auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); std::vector output_args; output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr& o) -> AnfNodePtr { return eqv[o]; }); + [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); // Set output for AnfGraph auto fg_output = fg->NewCNode(output_args); @@ -148,7 +148,7 @@ std::tuple TransformSegmentToAnfGr // This implementation will convert the nodes into a subgraph // that will run using the MsVM. template -LinConvertResult Convert(const AnfNodePtrList& lst) { +LinConvertResult Convert(const AnfNodePtrList &lst) { auto cached = g_ConvertCache.find(lst); if (cached != g_ConvertCache.end()) { return cached->second; @@ -168,7 +168,7 @@ LinConvertResult Convert(const AnfNodePtrList& lst) { std::shared_ptr vm = std::make_shared(); result.run = - std::make_shared([fg, vm](const VectorRef& args) -> VectorRef { return vm->RunGraph(fg, args); }); + std::make_shared([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); }); result.inputs = inputs; result.outputs = outputs; result.graph_id = UINT32_MAX; diff --git a/mindspore/ccsrc/vm/segment_runner.h b/mindspore/ccsrc/vm/segment_runner.h index 112a770de8d..8ea87da50c3 100644 --- a/mindspore/ccsrc/vm/segment_runner.h +++ b/mindspore/ccsrc/vm/segment_runner.h @@ -43,7 +43,7 @@ struct LinConvertResult { uint32_t graph_id; }; -using LinkFuncType = std::function; +using LinkFuncType = std::function; using ConvertCache = std::unordered_map; extern LinkFuncType MsVmConvert; extern LinkFuncType GeVmConvert; @@ -53,7 +53,7 @@ extern std::set backend_list; void ClearConvertCache(); -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList& lst); +std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst); } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 92976e0ddb9..1c3c917daef 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -41,12 +41,12 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple}; -const std::vector& GetMsNonlinearOps() { +const std::vector &GetMsNonlinearOps() { static const std::vector ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch}; return ms_nonlinear_ops; } -CompileGraph::CompileGraph(const BackendPtr& backend, const std::vector& cut_list) +CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend), cut_list_(cut_list) { MS_EXCEPTION_IF_NULL(backend_); lin_convert_ = backend_->convert_fn(); @@ -61,11 +61,11 @@ CompileGraph::CompileGraph(const BackendPtr& backend, const std::vectorisa()) { auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); if (inputs.empty()) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } @@ -76,7 +76,7 @@ bool CompileGraph::IsCut(const AnfNodePtr& node) { } PrimitivePtr node_prim = GetValueNode(fn); - for (auto& prim : cut_list_) { + for (auto &prim : cut_list_) { MS_EXCEPTION_IF_NULL(prim); if (prim->name() == node_prim->name()) { return true; @@ -97,14 +97,14 
@@ bool CompileGraph::IsCut(const AnfNodePtr& node) { return false; } -VectorRef CompileGraph::SplitNodes(const FuncGraphPtr& graph) { +VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); VectorRef splits; VectorRef split; std::vector nodes = TopoSort(graph->get_return()); MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); - for (auto& node : nodes) { + for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (IsCut(node)) { MS_LOG(DEBUG) << "Cut node:" << node->DebugString(10) << ", size:" << split.size(); @@ -123,7 +123,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr& graph) { } // Push the value node on the stack. -void CompileGraph::Push(const AnfNodePtr& node) { +void CompileGraph::Push(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) > 0) { MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString() @@ -135,25 +135,25 @@ void CompileGraph::Push(const AnfNodePtr& node) { set_height(height_ + 1); } -void CompileGraph::AddInst(const Instruction& inst, const int& arg) { +void CompileGraph::AddInst(const Instruction &inst, const int &arg) { VectorRef args; args.push_back(arg); AddInst(inst, args); } -void CompileGraph::AddInst(const Instruction& inst, const ValuePtr& arg) { +void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) { VectorRef args; args.push_back(arg); AddInst(inst, args); } -void CompileGraph::AddInst(const Instruction& inst, const VectorRef& args) { +void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) { inst_.push_back(std::make_pair(inst, args)); } // Gets the stack reference for the node value. If the node is a constant, // it may actually cause the push in to not be mentioned before. -int CompileGraph::Ref(const AnfNodePtr& node) { +int CompileGraph::Ref(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_; if (slots_.count(node) == 0 && node->isa()) { @@ -176,7 +176,7 @@ int CompileGraph::Ref(const AnfNodePtr& node) { } // Make sure the value of node is at the top of the stack. 
-void CompileGraph::AddInput(const AnfNodePtr& node) { +void CompileGraph::AddInput(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) == 0) { MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true); @@ -190,7 +190,7 @@ void CompileGraph::AddInput(const AnfNodePtr& node) { // Call back effect in stack void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); } -void CompileGraph::PushParameters(const FuncGraphPtr& graph) { +void CompileGraph::PushParameters(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector parameters = graph->parameters(); for (size_t i = parameters.size(); i != 0; i--) { @@ -199,7 +199,7 @@ void CompileGraph::PushParameters(const FuncGraphPtr& graph) { } } -int CompileGraph::LinConvert(const FuncGraphPtr& graph, const AnfNodePtrList& node_list) { +int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) { MS_LOG(DEBUG) << "LinConvert start"; LinConvertResult result; @@ -227,14 +227,14 @@ int CompileGraph::LinConvert(const FuncGraphPtr& graph, const AnfNodePtrList& no } } AddExternal(result); - for (auto& o : result.outputs) { + for (auto &o : result.outputs) { Push(o); } return RET_SUCCESS; } -void CompileGraph::AddSinkSwitch(const CNodePtr& node) { +void CompileGraph::AddSinkSwitch(const CNodePtr &node) { MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString(); if (backend_->is_multi_graph_sink()) { VectorRef args; @@ -255,7 +255,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr& node) { } } -int CompileGraph::InterpretNode(const FuncGraphPtr& graph, const CNodePtr& node) { +int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true); std::vector node_inputs = node->inputs(); @@ -293,7 +293,7 @@ int CompileGraph::InterpretNode(const FuncGraphPtr& graph, const CNodePtr& node) return RET_SUCCESS; } -void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr& graph) { +void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) { auto ret = LinConvert(graph, {}); if (ret == RET_FAILED) { MS_LOG(EXCEPTION) << "MultiGraphRun failed."; @@ -301,20 +301,20 @@ void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr& graph) { AddReturn(nullptr); } -bool CompileGraph::SplitGraph(const FuncGraphPtr& graph) { +bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start split graph"; MS_EXCEPTION_IF_NULL(graph); VectorRef splits = SplitNodes(graph); MS_LOG(DEBUG) << "Split nodes size:" << splits.size(); - for (auto& split : splits) { + for (auto &split : splits) { int ret = RET_SUCCESS; if (utils::isa(split)) { MS_LOG(DEBUG) << "Start a extern LinConvert"; std::vector args; auto vec_ref = utils::cast(split); (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), - [](const BaseRef& v) { return utils::cast(v); }); + [](const BaseRef &v) { return utils::cast(v); }); ret = LinConvert(graph, args); MS_LOG(DEBUG) << "End a extern LinConvert"; if (ret == RET_FAILED) { @@ -340,12 +340,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr& graph) { return true; } -InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr& graph) { +InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) { InstSet inst = Run(graph); return inst; } -InstSet CompileGraph::Run(const FuncGraphPtr& graph) { +InstSet CompileGraph::Run(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(DEBUG) << "Compile start graph: " << 
graph->ToString(); @@ -378,7 +378,7 @@ void CompileGraph::AddPadStack(int param_height) { } } -void CompileGraph::AddTailCall(const AnfNodePtr& fn, size_t size) { +void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) { VectorRef args; args.emplace_back(Ref(fn)); args.emplace_back(height_); @@ -387,7 +387,7 @@ void CompileGraph::AddTailCall(const AnfNodePtr& fn, size_t size) { AddInst(Instruction::kTailCall, args); } -void CompileGraph::AddPartial(const CNodePtr& node) { +void CompileGraph::AddPartial(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; for (size_t i = 1; i < inputs.size(); i++) { @@ -396,7 +396,7 @@ void CompileGraph::AddPartial(const CNodePtr& node) { AddInst(Instruction::kPartial, args); } -void CompileGraph::AddMakeTuple(const CNodePtr& node) { +void CompileGraph::AddMakeTuple(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; for (size_t i = 1; i < inputs.size(); i++) { @@ -405,7 +405,7 @@ void CompileGraph::AddMakeTuple(const CNodePtr& node) { AddInst(Instruction::kTuple, args); } -void CompileGraph::AddSwitch(const CNodePtr& node) { +void CompileGraph::AddSwitch(const CNodePtr &node) { auto inputs = node->inputs(); if (inputs.size() < 4) { MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4"; @@ -420,7 +420,7 @@ void CompileGraph::AddSwitch(const CNodePtr& node) { AddInst(Instruction::kSwitch, args); } -void CompileGraph::AddReturn(const CNodePtr& node) { +void CompileGraph::AddReturn(const CNodePtr &node) { VectorRef args; if (backend_->simu_flag()) { args.emplace_back(Ref(backend_->final_output())); @@ -431,7 +431,7 @@ void CompileGraph::AddReturn(const CNodePtr& node) { AddInst(Instruction::kReturn, args); } -void CompileGraph::AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim) { +void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) { auto inputs = node->inputs(); VectorRef args; args.push_back(prim); @@ -441,7 +441,7 @@ void CompileGraph::AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim) AddInst(Instruction::kPrim, args); } -int CompileGraph::AddCall(const FuncGraphPtr& graph, const CNodePtr& node) { +int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { auto node_inputs = node->inputs(); AnfNodePtr fn = node_inputs[0]; (void)Ref(fn); @@ -459,7 +459,7 @@ int CompileGraph::AddCall(const FuncGraphPtr& graph, const CNodePtr& node) { return RET_SUCCESS; } -void CompileGraph::AddExternal(const LinConvertResult& result) { +void CompileGraph::AddExternal(const LinConvertResult &result) { VectorRef args; args.push_back(result.run); args.push_back(result.simu_run); @@ -471,16 +471,16 @@ void CompileGraph::AddExternal(const LinConvertResult& result) { } void TraverseGraphMap( - const FuncGraphManagerPtr& manager_ptr, FuncGraphTransaction* const tr, const FuncGraphToAnfNodeCounterMap& cts, - const std::function(const PrimitivePtr, const AbstractFunctionPtr)>& get_prim_graph) { + const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, + const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr); - for (const auto& ct_graphs : cts) { - for (const auto& ct_any : ct_graphs.second) { + for (const auto &ct_graphs : cts) { + for (const auto &ct_any : ct_graphs.second) { AnfNodePtr const_primitive_node = ct_any.first; if (const_primitive_node != nullptr && 
IsValueNode(const_primitive_node)) { auto users = manager_ptr->node_users()[const_primitive_node]; - for (auto& use : users) { + for (auto &use : users) { CNodePtr node = use.first->cast(); MS_EXCEPTION_IF_NULL(node); int key = use.second; @@ -503,12 +503,12 @@ void TraverseGraphMap( } } -FuncGraphPtr WrapPrimitives(const FuncGraphPtr& graph) { +FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); FuncGraphManagerPtr manager_ptr = graph->manager(); MS_EXCEPTION_IF_NULL(manager_ptr); MapPrimTypeFuncGraph prim_graphs; - auto get_prim_graph = [&](const PrimitivePtr& prim, const AbstractFunctionPtr& type) { + auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) { PrimTypePair prim_type = std::make_pair(prim, type); if (prim_graphs.end() == prim_graphs.find(prim_type)) { FuncGraphPtr g = std::make_shared(); @@ -536,13 +536,13 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr& graph) { }; FuncGraphTransaction tr = manager_ptr->Transact(); - auto& cts = manager_ptr->valuenodes(); + auto &cts = manager_ptr->valuenodes(); TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); return graph; } -CompileGraphs::CompileGraphs(const BackendPtr& backend, const std::vector& cut_list) : backend_(backend) { +CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend) { MS_EXCEPTION_IF_NULL(backend); MS_LOG(DEBUG) << "Start vm: " << backend->name(); transform_ = std::make_shared(backend, cut_list); @@ -550,12 +550,12 @@ CompileGraphs::CompileGraphs(const BackendPtr& backend, const std::vectormanager(); MS_EXCEPTION_IF_NULL(graph_manager); FuncGraphSet graphs = graph_manager->func_graphs(); - for (auto& g : graphs) { + for (auto &g : graphs) { mapping_[g] = static_cast(insts_.size()); if (transform_ != nullptr) { InstSet insts = transform_->Run(g); @@ -568,7 +568,7 @@ void CompileGraphs::Compile(const FuncGraphPtr& graph) { } // Link instructions from multiple function graphs together. -FinalVMPtr CompileGraphs::Link(const FuncGraphPtr& graph) { +FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start"; for (std::size_t i = 0; i < insts_.size(); i++) { InstType inst = insts_[i]; @@ -600,7 +600,7 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr& graph) { } // Convert all graphs to unlinked instructions and link them. -FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr& graph) { +FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(DEBUG) << "Start"; Reset(); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 290af100491..711c1777ab9 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -42,26 +42,26 @@ extern const char kGeVm[]; // A sub namespace in ME to support compile related definition. 
namespace compile { extern std::vector nonlinear_ops; -const std::vector& GetMsNonlinearOps(); +const std::vector &GetMsNonlinearOps(); -using VmEvalFunc = std::function; -using VmEvalFuncPtr = std::shared_ptr>; +using VmEvalFunc = std::function; +using VmEvalFuncPtr = std::shared_ptr>; class CompileGraph { public: - explicit CompileGraph(const BackendPtr& backend, const std::vector& cut_list = nonlinear_ops); + explicit CompileGraph(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); ~CompileGraph() = default; - InstSet Run(const FuncGraphPtr& func_graph); - InstSet GenMultiGraphsSinkInst(const FuncGraphPtr& graph); - bool IsCut(const AnfNodePtr& node); - void Push(const AnfNodePtr& node); - void Tie(const AnfNodePtr& n1, const AnfNodePtr& n2) { slots_[n2] = slots_[n1]; } + InstSet Run(const FuncGraphPtr &func_graph); + InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph); + bool IsCut(const AnfNodePtr &node); + void Push(const AnfNodePtr &node); + void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Ret(int nargs); - void GenMultiGraphsRun(const FuncGraphPtr& graph); - int Ref(const AnfNodePtr& node); - VectorRef SplitNodes(const FuncGraphPtr& func_graph); + void GenMultiGraphsRun(const FuncGraphPtr &graph); + int Ref(const AnfNodePtr &node); + VectorRef SplitNodes(const FuncGraphPtr &func_graph); void set_height(int h) { height_ = h; @@ -78,24 +78,24 @@ class CompileGraph { } private: - void PushParameters(const FuncGraphPtr& func_graph); - bool SplitGraph(const FuncGraphPtr& func_graph); - int LinConvert(const FuncGraphPtr& func_graph, const AnfNodePtrList& node_list); - int InterpretNode(const FuncGraphPtr& func_graph, const CNodePtr& node); - int AddCall(const FuncGraphPtr& graph, const CNodePtr& node); - void AddSinkSwitch(const CNodePtr& node); + void PushParameters(const FuncGraphPtr &func_graph); + bool SplitGraph(const FuncGraphPtr &func_graph); + int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list); + int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); + int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); + void AddSinkSwitch(const CNodePtr &node); void AddPadStack(int param_height); - void AddTailCall(const AnfNodePtr& fn, size_t size); - void AddPartial(const CNodePtr& node); - void AddMakeTuple(const CNodePtr& node); - void AddSwitch(const CNodePtr& node); - void AddReturn(const CNodePtr& node); - void AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim); - void AddInput(const AnfNodePtr& node); - void AddExternal(const LinConvertResult& result); - void AddInst(const Instruction& inst, const int& arg); - void AddInst(const Instruction& inst, const ValuePtr& arg); - void AddInst(const Instruction& inst, const VectorRef& args); + void AddTailCall(const AnfNodePtr &fn, size_t size); + void AddPartial(const CNodePtr &node); + void AddMakeTuple(const CNodePtr &node); + void AddSwitch(const CNodePtr &node); + void AddReturn(const CNodePtr &node); + void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); + void AddInput(const AnfNodePtr &node); + void AddExternal(const LinConvertResult &result); + void AddInst(const Instruction &inst, const int &arg); + void AddInst(const Instruction &inst, const ValuePtr &arg); + void AddInst(const Instruction &inst, const VectorRef &args); BackendPtr backend_; LinkFuncType lin_convert_; @@ -112,7 +112,7 @@ using CompileGraphPtr = std::shared_ptr; // CompileGraphs is used to Convert a graph cluster into 
instruction lists. class CompileGraphs { public: - explicit CompileGraphs(const BackendPtr& backend, const std::vector& cut_list = nonlinear_ops); + explicit CompileGraphs(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); ~CompileGraphs() = default; @@ -121,9 +121,9 @@ class CompileGraphs { mapping_.clear(); } - void Compile(const FuncGraphPtr& func_graph); - FinalVMPtr Link(const FuncGraphPtr& func_graph); - FinalVMPtr CompileAndLink(const FuncGraphPtr& func_graph); + void Compile(const FuncGraphPtr &func_graph); + FinalVMPtr Link(const FuncGraphPtr &func_graph); + FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); private: InstSet insts_; diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index 493873b0bcf..95ceceb67f9 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -32,29 +32,29 @@ namespace compile { // Arguments: // fn_: Callable function. // args_: Sequence of function args. -StructPartial::StructPartial(int fn, const VectorRef& args) : fn_(fn), args_(args) {} +StructPartial::StructPartial(int fn, const VectorRef &args) : fn_(fn), args_(args) {} -std::ostream& operator<<(std::ostream& os, const StructPartial& other) { +std::ostream &operator<<(std::ostream &os, const StructPartial &other) { os << "partial(" << other.fn_ << ", " << other.args_.ToString() << ")"; return os; } -bool operator==(const StructPartial& lhs, const StructPartial& rhs) { +bool operator==(const StructPartial &lhs, const StructPartial &rhs) { return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_); } -StructSimuSwitch::StructSimuSwitch(const BaseRef& fn, const BaseRef& value) : fn_(fn), value_(value) {} +StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {} -std::ostream& operator<<(std::ostream& os, const StructSimuSwitch& other) { +std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) { os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")"; return os; } -bool operator==(const StructSimuSwitch& lhs, const StructSimuSwitch& rhs) { +bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) { return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_); } -std::ostream& operator<<(std::ostream& os, const SwitchCondStatus& other) { +std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) { os << "SwitchCondStatus(" << static_cast(other) << ")"; return os; } @@ -66,13 +66,13 @@ std::ostream& operator<<(std::ostream& os, const SwitchCondStatus& other) { // retp_: The call stack. 
// pc_: program counter (next instruction) // sp_: stack pointer (for the value stack) -FinalVM::FinalVM(const InstSet& insts, const BackendPtr& backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) { +FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) { MS_LOG(DEBUG) << "InstSet size:" << insts_.size(); insts_stack_.emplace_back(BaseRef()); retp_.push(-1); } -void FinalVM::Push(const BaseRef& v) { +void FinalVM::Push(const BaseRef &v) { MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_; insts_stack_[IntToSize(sp_++)] = v; } @@ -140,7 +140,7 @@ void FinalVM::Popsp() { } } -void FinalVM::DoJmp(const BaseRef& jmp_orig) { +void FinalVM::DoJmp(const BaseRef &jmp_orig) { MS_LOG(DEBUG) << "Start"; BaseRef jmp = jmp_orig; @@ -173,7 +173,7 @@ void FinalVM::DoJmp(const BaseRef& jmp_orig) { MS_LOG(DEBUG) << "End do jump pc_:" << pc_; } -BaseRef FinalVM::Eval(const VectorRef& args) { +BaseRef FinalVM::Eval(const VectorRef &args) { MS_LOG(DEBUG) << "Start: " << args.size(); insts_stack_.clear(); insts_stack_.resize(args.size()); @@ -212,7 +212,7 @@ BaseRef FinalVM::Eval(const VectorRef& args) { return insts_stack_[0]; } -void FinalVM::InstCall(const VectorRef& args) { +void FinalVM::InstCall(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -228,7 +228,7 @@ void FinalVM::InstCall(const VectorRef& args) { MS_LOG(DEBUG) << "Instcall end sp :" << sp_; } -void FinalVM::InstTailCall(const VectorRef& args) { +void FinalVM::InstTailCall(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 3; if (args.size() != args_size) { @@ -258,7 +258,7 @@ void FinalVM::InstTailCall(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSwitchReturn(const VectorRef& args) { +void FinalVM::InstSwitchReturn(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; if (args.size() != 1) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; @@ -268,7 +268,7 @@ void FinalVM::InstSwitchReturn(const VectorRef& args) { Popsp(); } -void FinalVM::InstReturn(const VectorRef& args) { +void FinalVM::InstReturn(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 2; if (args.size() != args_size) { @@ -291,7 +291,7 @@ void FinalVM::InstReturn(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPartial(const VectorRef& args) { +void FinalVM::InstPartial(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() < args_size) { @@ -306,12 +306,12 @@ void FinalVM::InstPartial(const VectorRef& args) { std::vector outs(args.size() - 1); (void)std::transform(args.begin() + 1, args.end(), outs.begin(), - [&, this](const BaseRef& a) { return Ref(utils::cast(a)); }); + [&, this](const BaseRef &a) { return Ref(utils::cast(a)); }); Push(std::make_shared(fn, VectorRef(outs))); MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSimuSwitch(const VectorRef& args) { +void FinalVM::InstSimuSwitch(const VectorRef &args) { const size_t args_size = 4; if (args.size() != args_size) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " @@ -365,7 +365,7 @@ void FinalVM::InstSimuSwitch(const VectorRef& args) { } } -void FinalVM::InstRealSwitch(const VectorRef& args) { +void FinalVM::InstRealSwitch(const VectorRef &args) { const size_t args_size = 3; if (args.size() != args_size) { MS_LOG(ERROR) << "" << 
__FUNCTION__ << " requires " << args_size << " parameters, while the input size is " @@ -392,7 +392,7 @@ void FinalVM::InstRealSwitch(const VectorRef& args) { } } -void FinalVM::InstSwitch(const VectorRef& args) { +void FinalVM::InstSwitch(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; if (backend_->is_multi_graph_sink()) { InstSimuSwitch(args); @@ -401,7 +401,7 @@ void FinalVM::InstSwitch(const VectorRef& args) { } } -void FinalVM::InstTuple(const VectorRef& args) { +void FinalVM::InstTuple(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; VectorRef tuple; auto iter = args.begin(); @@ -413,7 +413,7 @@ void FinalVM::InstTuple(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPush(const VectorRef& args) { +void FinalVM::InstPush(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -427,7 +427,7 @@ void FinalVM::InstPush(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstInput(const VectorRef& args) { +void FinalVM::InstInput(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -441,7 +441,7 @@ void FinalVM::InstInput(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPadStack(const VectorRef& args) { +void FinalVM::InstPadStack(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -461,7 +461,7 @@ void FinalVM::InstPadStack(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstExternal(const VectorRef& args) { +void FinalVM::InstExternal(const VectorRef &args) { MS_LOG(DEBUG) << "Start:" << args.size(); if (args.empty()) { @@ -490,14 +490,14 @@ void FinalVM::InstExternal(const VectorRef& args) { auto outs = (*fn)(tuple); MS_LOG(DEBUG) << "'fn' out size:" << outs.size(); - for (auto& o : outs) { + for (auto &o : outs) { MS_LOG(DEBUG) << "InstExternal value:" << o.ToString(); Push(o); } MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPushPrim(const VectorRef& args) { +void FinalVM::InstPushPrim(const VectorRef &args) { MS_LOG(DEBUG) << "Start: " << args.size(); const size_t args_size = 2; if (args.size() < args_size) { diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index 3e1e5b5c085..eab726a9b7f 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -53,14 +53,14 @@ enum Instruction { using InstType = std::pair; using InstSet = std::vector; -using InstFunctionMap = std::map>; +using InstFunctionMap = std::map>; const std::vector inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", "input", "external", "push", "primitive", "graph", "pad_stack"}; class StructPartial : public Base { public: // Initialize StructPartial. 
- StructPartial(int fn, const VectorRef& args); + StructPartial(int fn, const VectorRef &args); virtual ~StructPartial() = default; MS_DECLARE_PARENT(StructPartial, Base) @@ -69,12 +69,12 @@ class StructPartial : public Base { VectorRef args_; }; -std::ostream& operator<<(std::ostream& os, const StructPartial& other); -bool operator==(const StructPartial& lhs, const StructPartial& rhs); +std::ostream &operator<<(std::ostream &os, const StructPartial &other); +bool operator==(const StructPartial &lhs, const StructPartial &rhs); class StructSimuSwitch : public Base { public: - StructSimuSwitch(const BaseRef& fn, const BaseRef& value); + StructSimuSwitch(const BaseRef &fn, const BaseRef &value); virtual ~StructSimuSwitch() = default; MS_DECLARE_PARENT(StructSimuSwitch, Base) @@ -83,43 +83,43 @@ class StructSimuSwitch : public Base { BaseRef value_; }; -std::ostream& operator<<(std::ostream& os, const StructSimuSwitch& other); -bool operator==(const StructSimuSwitch& lhs, const StructSimuSwitch& rhs); +std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other); +bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs); class FinalVM { public: // Create a VM with the specified instructions and backend. - explicit FinalVM(const InstSet& insts, const BackendPtr& backend); + explicit FinalVM(const InstSet &insts, const BackendPtr &backend); virtual ~FinalVM() = default; - BaseRef Eval(const VectorRef& args); - void InstCall(const VectorRef& args); - void InstTailCall(const VectorRef& args); - void InstReturn(const VectorRef& args); - void InstPartial(const VectorRef& args); - void InstSwitch(const VectorRef& args); - void InstSimuSwitch(const VectorRef& args); - void InstRealSwitch(const VectorRef& args); - void InstTuple(const VectorRef& args); - void InstPush(const VectorRef& args); - void InstInput(const VectorRef& args); - void InstPadStack(const VectorRef& args); - void InstExternal(const VectorRef& args); - void InstPushPrim(const VectorRef& args); - void InstSwitchReturn(const VectorRef& args); - void set_insts(const InstSet& value) { insts_ = value; } + BaseRef Eval(const VectorRef &args); + void InstCall(const VectorRef &args); + void InstTailCall(const VectorRef &args); + void InstReturn(const VectorRef &args); + void InstPartial(const VectorRef &args); + void InstSwitch(const VectorRef &args); + void InstSimuSwitch(const VectorRef &args); + void InstRealSwitch(const VectorRef &args); + void InstTuple(const VectorRef &args); + void InstPush(const VectorRef &args); + void InstInput(const VectorRef &args); + void InstPadStack(const VectorRef &args); + void InstExternal(const VectorRef &args); + void InstPushPrim(const VectorRef &args); + void InstSwitchReturn(const VectorRef &args); + void set_insts(const InstSet &value) { insts_ = value; } protected: BaseRef Ref(int i); - void Push(const BaseRef& v); + void Push(const BaseRef &v); void Pop(int n = 1); void MoveStack(int nitems, int height); void Pushp(); void Popp(); void Pushsp(); void Popsp(); - void DoJmp(const BaseRef& jmp); + void DoJmp(const BaseRef &jmp); private: InstSet insts_; @@ -130,18 +130,18 @@ class FinalVM { int sp_; BackendPtr backend_; const InstFunctionMap inst_function_map = { - {Instruction::kCall, [this](const VectorRef& args) { InstCall(args); }}, - {Instruction::kTailCall, [this](const VectorRef& args) { InstTailCall(args); }}, - {Instruction::kReturn, [this](const VectorRef& args) { InstReturn(args); }}, - {Instruction::kPartial, [this](const VectorRef& args) { InstPartial(args); 
}}, - {Instruction::kSwitch, [this](const VectorRef& args) { InstSwitch(args); }}, - {Instruction::kTuple, [this](const VectorRef& args) { InstTuple(args); }}, - {Instruction::kPush, [this](const VectorRef& args) { InstPush(args); }}, - {Instruction::kInput, [this](const VectorRef& args) { InstInput(args); }}, - {Instruction::kPadStack, [this](const VectorRef& args) { InstPadStack(args); }}, - {Instruction::kExternal, [this](const VectorRef& args) { InstExternal(args); }}, - {Instruction::kPrim, [this](const VectorRef& args) { InstPushPrim(args); }}, - {Instruction::kSwitchReturn, [this](const VectorRef& args) { InstSwitchReturn(args); }}, + {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, + {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }}, + {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }}, + {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }}, + {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }}, + {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }}, + {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }}, + {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }}, + {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }}, + {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, + {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, + {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, }; }; diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index ee9a817dd8a..017121f334d 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -40,25 +40,25 @@ using PrimitivePyPtr = std::shared_ptr; // Indicate a call to a new frame. struct CallWrap : public Base { - explicit CallWrap(const VMFramePtr& vm_frame) : frame(vm_frame) {} + explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {} VMFramePtr frame{nullptr}; }; using CallWrapPtr = std::shared_ptr; // Indicates a return with its value. 
struct ReturnWrap : public Base { - explicit ReturnWrap(const BaseRef& r_value) : value(r_value) {} + explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {} BaseRef value{BaseRef()}; }; using ReturnWrapPtr = std::shared_ptr; -VMFrame::VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, - const AnfNodePtrToBaseRefMap& closure) +VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, + const AnfNodePtrToBaseRefMap &closure) : values_(values), todo_(nodes), closure_(closure) { std::reverse(std::begin(todo_), std::end(todo_)); } -const BaseRef VMFrame::operator[](const AnfNodePtr& node) { +const BaseRef VMFrame::operator[](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto ret = values_.find(node); if (ret != values_.end()) { @@ -77,31 +77,31 @@ const BaseRef VMFrame::operator[](const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "ValueError " << node->type_name(); } -Closure::Closure(const FuncGraphPtr& graph, const AnfNodePtrToBaseRefMap& values) +Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values) : func_graph_(graph), values_(values) {} -BaseRef Closure::operator()(const VectorRef& args) { +BaseRef Closure::operator()(const VectorRef &args) { MS_LOG(DEBUG) << "start closure"; return vm_->Evaluate(func_graph_, args, values_); } -Partial::Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm) : fn_(fn), args_(args), vm_(vm) {} +Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {} -BaseRef Partial::operator()(const VectorRef& nodes) { +BaseRef Partial::operator()(const VectorRef &nodes) { VectorRef arglist; (void)arglist.insert(arglist.end(), args_.begin(), args_.end()); (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end()); return vm_->Call(fn_, arglist); } -SetRef VM::ComputeFvs(const FuncGraphPtr& graph) { +SetRef VM::ComputeFvs(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); SetRef rval; - for (auto& fkv : graph->free_variables_total()) { + for (auto &fkv : graph->free_variables_total()) { if (utils::isa(fkv.first)) { // Add all value_nodes of g that refer to a fv graph auto g = utils::cast(fkv.first); - for (auto& ctkv : g->value_nodes()) { + for (auto &ctkv : g->value_nodes()) { auto ct = ctkv.first; if (GetValueNode(ct) == g) { (void)rval.insert(ct); @@ -116,7 +116,7 @@ SetRef VM::ComputeFvs(const FuncGraphPtr& graph) { return rval; } -void VM::AcquireGraph(const FuncGraphPtr& graph) { +void VM::AcquireGraph(const FuncGraphPtr &graph) { // Already acquired if (vars_.find(graph) != vars_.end()) { return; @@ -130,30 +130,30 @@ void VM::AcquireGraph(const FuncGraphPtr& graph) { } } -VectorRef VM::ExportSequence(const VectorRef& seq) { +VectorRef VM::ExportSequence(const VectorRef &seq) { std::vector ret; (void)std::transform(std::begin(seq), std::end(seq), std::back_inserter(ret), - [&, this](const BaseRef& x) -> BaseRef { return Export(x); }); + [&, this](const BaseRef &x) -> BaseRef { return Export(x); }); return VectorRef(ret); } -ClosurePtr VM::ExportClosure(const ClosurePtr& clos) { +ClosurePtr VM::ExportClosure(const ClosurePtr &clos) { MS_EXCEPTION_IF_NULL(clos); clos->set_vm(shared_from_this()); return clos; } // transform graph to executable closure -ClosurePtr VM::ExportGraph(const FuncGraphPtr& g) { +ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) { auto c = std::make_shared(g, AnfNodePtrToBaseRefMap()); MS_EXCEPTION_IF_NULL(c); c->set_vm(shared_from_this()); return c; } -BaseRef 
VM::ExportObj(const BaseRef& obj) const { return obj; } +BaseRef VM::ExportObj(const BaseRef &obj) const { return obj; } -BaseRef VM::Export(const BaseRef& value) { +BaseRef VM::Export(const BaseRef &value) { if (utils::isa(value) && utils::cast(value)->isa()) { return ExportGraph(utils::cast(value)->cast()); } @@ -183,7 +183,7 @@ BaseRef VM::Export(const BaseRef& value) { // Run a graph. // This will evaluate the passed-in graph and return the resulting value. -BaseRef VM::Evaluate(const FuncGraphPtr& graph, const VectorRef& args, const AnfNodePtrToBaseRefMap& closure) { +BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) { AcquireGraph(graph); MS_LOG(DEBUG) << "evalue arg size: " << args.size(); if (args.size() != graph->parameters().size()) { @@ -237,15 +237,15 @@ BaseRef VM::Evaluate(const FuncGraphPtr& graph, const VectorRef& args, const Anf MS_LOG(EXCEPTION) << "VM Evaluate error"; } -SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { - auto fn = [&, this](const AnfNodePtr& node) -> AnfNodePtrList { +SuccFunc VM::SuccVm(const FuncGraphPtr &graph) { + auto fn = [&, this](const AnfNodePtr &node) -> AnfNodePtrList { MS_EXCEPTION_IF_NULL(node); AnfNodePtrList ret; // Follow node.incoming if (node->isa()) { - auto& inputs = node->cast()->inputs(); - for (auto& i : inputs) { + auto &inputs = node->cast()->inputs(); + for (auto &i : inputs) { if (i->func_graph() == node->func_graph() || (IsValueNode(i) && GetValueNode(i)->parent() == graph)) { ret.push_back(i); @@ -257,7 +257,7 @@ SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { if (IsValueNode(node) && GetValueNode(node)->parent() == graph) { auto fvs = utils::cast(vars_[GetValueNode(node)]); (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(ret), - [](const BaseRef& value) -> AnfNodePtr { return utils::cast(value); }); + [](const BaseRef &value) -> AnfNodePtr { return utils::cast(value); }); } return ret; @@ -265,7 +265,7 @@ SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { return fn; } -BaseRef VM::Call(const BaseRef& fn, const VectorRef& args) { +BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn)) { return RunOperation(utils::cast(fn), args); } @@ -283,7 +283,7 @@ BaseRef VM::Call(const BaseRef& fn, const VectorRef& args) { } // make call frame for graph -BaseRef VM::_Call(const BaseRef& graph, const VectorRef& args) { +BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) { AnfNodePtrToBaseRefMap clos; auto func_graph = graph; if (utils::isa(func_graph)) { @@ -319,11 +319,11 @@ BaseRef VM::_Call(const BaseRef& graph, const VectorRef& args) { } // make closure out of graph with fv values from frame -ClosurePtr VM::MakeClosure(const FuncGraphPtr& graph, const VMFramePtr& frame) { +ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(frame); AnfNodePtrToBaseRefMap clos; - for (auto& v : utils::cast(vars_[graph])) { + for (auto &v : utils::cast(vars_[graph])) { auto anf = utils::cast(v); clos[anf] = (*frame)[anf]; } @@ -331,7 +331,7 @@ ClosurePtr VM::MakeClosure(const FuncGraphPtr& graph, const VMFramePtr& frame) { return std::make_shared(graph, clos); } -BaseRef VM::DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args) { +BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn) && utils::cast(fn)->isa()) { auto fnval = utils::cast(fn)->cast(); MS_LOG(DEBUG) 
<< "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true); @@ -384,7 +384,7 @@ BaseRef VM::DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const MS_LOG(EXCEPTION) << "Invalid fn to call"; } -BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { +BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { // pass @@ -409,10 +409,10 @@ BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { if (node->isa()) { std::vector fnArgs; - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); // set args' values in frame (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs), - [&](const AnfNodePtr& inp) -> BaseRef { return (*frame)[inp]; }); + [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; }); if (fnArgs.empty()) { MS_LOG(EXCEPTION) << "function arguments is empty"; } else { @@ -425,7 +425,7 @@ BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { MS_LOG(EXCEPTION) << "Unknown node type"; } -VectorRef VM::RunGraph(const FuncGraphPtr& g, const VectorRef& args) { +VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { this->manager_ = Manage(g); auto fn = utils::cast(Export(g)); @@ -439,7 +439,7 @@ VectorRef VM::RunGraph(const FuncGraphPtr& g, const VectorRef& args) { } } -BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args) { +BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { PrimitivePyPtr operation = dyn_cast(prim); MS_LOG(DEBUG) << "operation start " << prim->name(); @@ -451,7 +451,7 @@ BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args) { py::tuple py_args = py::tuple(args.size()); MS_LOG(DEBUG) << "input for operation:"; size_t i = 0; - for (auto& arg : args) { + for (auto &arg : args) { py_args[i] = BaseRefToPyData(arg); MS_LOG(DEBUG) << "arg: " << i << ":"; i++; diff --git a/mindspore/ccsrc/vm/vmimpl.h b/mindspore/ccsrc/vm/vmimpl.h index 4ef507af826..11d026fe721 100644 --- a/mindspore/ccsrc/vm/vmimpl.h +++ b/mindspore/ccsrc/vm/vmimpl.h @@ -53,14 +53,14 @@ using VMPtr = std::shared_ptr; class Partial; using PartialPtr = std::shared_ptr; -using RunFunc = std::function; +using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; using SuccFunc = std::function; class VMImpl { public: - virtual VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) = 0; + virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0; virtual ~VMImpl() = default; }; @@ -76,11 +76,11 @@ class VMImpl { // closure: values for the closure if the current application is a closure class VMFrame { public: - VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, const AnfNodePtrToBaseRefMap& closure); - const BaseRef operator[](const AnfNodePtr& node); - const AnfNodePtrList& todo() const { return todo_; } + VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure); + const BaseRef operator[](const AnfNodePtr &node); + const AnfNodePtrList &todo() const { return todo_; } - AnfNodePtrToBaseRefMap& values() { return values_; } + AnfNodePtrToBaseRefMap &values() { return values_; } virtual ~VMFrame() = default; @@ -94,16 +94,16 @@ class VMFrame { // Representation of a closure. 
class Closure : public Base { public: - Closure(const FuncGraphPtr& func_graph, const AnfNodePtrToBaseRefMap& values); - BaseRef operator()(const VectorRef& args); + Closure(const FuncGraphPtr &func_graph, const AnfNodePtrToBaseRefMap &values); + BaseRef operator()(const VectorRef &args); - const VMPtr& vm() const { return vm_; } + const VMPtr &vm() const { return vm_; } - void set_vm(const VMPtr& vm) { vm_ = vm; } + void set_vm(const VMPtr &vm) { vm_ = vm; } - const FuncGraphPtr& func_graph() const { return func_graph_; } + const FuncGraphPtr &func_graph() const { return func_graph_; } - const AnfNodePtrToBaseRefMap& values() const { return values_; } + const AnfNodePtrToBaseRefMap &values() const { return values_; } virtual ~Closure() = default; @@ -118,11 +118,11 @@ class Closure : public Base { // Representation of a partial application. class Partial : public Base { public: - Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm); - BaseRef operator()(const VectorRef& nodes); - const BaseRef& fn() const { return fn_; } + Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm); + BaseRef operator()(const VectorRef &nodes); + const BaseRef &fn() const { return fn_; } - const VectorRef& args() const { return args_; } + const VectorRef &args() const { return args_; } virtual ~Partial() = default; MS_DECLARE_PARENT(Partial, Base) @@ -136,52 +136,52 @@ class Partial : public Base { // Virtual Machine interface. class VM : public std::enable_shared_from_this, public VMImpl { public: - SetRef ComputeFvs(const FuncGraphPtr& func_graph); + SetRef ComputeFvs(const FuncGraphPtr &func_graph); - void AcquireGraph(const FuncGraphPtr& func_graph); + void AcquireGraph(const FuncGraphPtr &func_graph); - VectorRef ExportSequence(const VectorRef& seq); + VectorRef ExportSequence(const VectorRef &seq); - BaseRef ExportPrimitive(const PrimitivePtr&) const { return kAnyValue; } + BaseRef ExportPrimitive(const PrimitivePtr &) const { return kAnyValue; } - ClosurePtr ExportClosure(const ClosurePtr& clos); + ClosurePtr ExportClosure(const ClosurePtr &clos); // Return an object that executes `fg` when called on arguments. - ClosurePtr ExportGraph(const FuncGraphPtr& fg); + ClosurePtr ExportGraph(const FuncGraphPtr &fg); - BaseRef ExportObj(const BaseRef& obj) const; + BaseRef ExportObj(const BaseRef &obj) const; - BaseRef Export(const BaseRef& value); + BaseRef Export(const BaseRef &value); // Run a graph. // This will evaluate the passed-in graph and return the // resulting value. - BaseRef Evaluate(const FuncGraphPtr& func_graph, const VectorRef& args, - const AnfNodePtrToBaseRefMap& closure = AnfNodePtrToBaseRefMap()); + BaseRef Evaluate(const FuncGraphPtr &func_graph, const VectorRef &args, + const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap()); // Return a visitor for the graph. - SuccFunc SuccVm(const FuncGraphPtr& func_graph); + SuccFunc SuccVm(const FuncGraphPtr &func_graph); // Call the `fn` object. // `fn` can be anything that would be valid as the first element of an apply. 
- BaseRef Call(const BaseRef& fn, const VectorRef& args); + BaseRef Call(const BaseRef &fn, const VectorRef &args); - BaseRef _Call(const BaseRef& graph, const VectorRef& args); + BaseRef _Call(const BaseRef &graph, const VectorRef &args); - ClosurePtr MakeClosure(const FuncGraphPtr& func_graph, const VMFramePtr& frame); + ClosurePtr MakeClosure(const FuncGraphPtr &func_graph, const VMFramePtr &frame); - BaseRef DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args); + BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args); - BaseRef HandleNode(const AnfNodePtr& node, const VMFramePtr& frame); + BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame); - VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) override; + VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) override; private: FuncGraphManagerPtr manager_; FuncGraphPtrToBaseRefMap vars_; }; -extern BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args); +extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args); } // namespace compile } // namespace mindspore
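A minimal sketch of the declarator-adjacent `&`/`*` placement that every hunk above converges on; the class and function names here are illustrative only and are not part of the MindSpore sources touched by this change set.

// Hypothetical example, not taken from the files modified above.
#include <string>

class Logger {
 public:
  // References and pointers are written `const T &x` / `T *p`:
  // the `&`/`*` binds to the identifier, matching the rewritten hunks above.
  explicit Logger(const std::string &tag) : tag_(tag) {}
  Logger(const Logger &) = delete;
  Logger &operator=(const Logger &) = delete;

  const char *TagCStr() const { return tag_.c_str(); }

  // Out-parameters keep the same convention: `std::string *out`, not `std::string* out`.
  void Append(const std::string &line, std::string *out) const {
    if (out != nullptr) {
      *out += tag_ + ": " + line + "\n";
    }
  }

 private:
  std::string tag_;
};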