update clang format rule

This commit is contained in:
zhoufeng 2020-04-21 21:21:19 +08:00
parent 31a12009dd
commit c2b3360d69
278 changed files with 4447 additions and 4441 deletions

View File

@ -94,7 +94,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10 PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000 PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200 PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left PointerAlignment: Right
RawStringFormats: RawStringFormats:
- Language: Cpp - Language: Cpp
Delimiters: Delimiters:

View File

@ -23,7 +23,7 @@ namespace common {
const int CACHED_STR_NUM = 1 << 8; const int CACHED_STR_NUM = 1 << 8;
const int CACHED_STR_MASK = CACHED_STR_NUM - 1; const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM); std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
const char* SafeCStr(const std::string&& str) { const char *SafeCStr(const std::string &&str) {
static std::atomic<uint32_t> index{0}; static std::atomic<uint32_t> index{0};
uint32_t cur_index = index++; uint32_t cur_index = index++;
cur_index = cur_index & CACHED_STR_MASK; cur_index = cur_index & CACHED_STR_MASK;

View File

@ -21,16 +21,16 @@
#include <string> #include <string>
#define DISABLE_COPY_AND_ASSIGN(ClassType) \ #define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType&) = delete; \ ClassType(const ClassType &) = delete; \
ClassType& operator=(const ClassType&) = delete; ClassType &operator=(const ClassType &) = delete;
namespace mindspore { namespace mindspore {
namespace common { namespace common {
inline const char* SafeCStr(const std::string& str) { return str.c_str(); } inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
const char* SafeCStr(const std::string&& str); const char *SafeCStr(const std::string &&str);
static inline std::string GetEnv(const std::string& envvar) { static inline std::string GetEnv(const std::string &envvar) {
const char* value = ::getenv(envvar.c_str()); const char *value = ::getenv(envvar.c_str());
if (value == nullptr) { if (value == nullptr) {
return std::string(); return std::string();

View File

@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {
~DecodeOp() = default; ~DecodeOp() = default;
Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
void Print(std::ostream& out) const override { out << "DecodeOp"; } void Print(std::ostream &out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private: private:
bool is_rgb_format_ = true; bool is_rgb_format_ = true;

View File

@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int
rnd_.seed(seed_); rnd_.seed(seed_);
} }
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input, Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>>* output) { std::vector<std::shared_ptr<Tensor>> *output) {
IO_CHECK_VECTOR(input, output); IO_CHECK_VECTOR(input, output);
if (input.size() != NumInput()) if (input.size() != NumInput())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");
@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso
return Status::OK(); return Status::OK();
} }
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs, Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs,
std::vector<TensorShape>& outputs) { std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear(); outputs.clear();
TensorShape out = TensorShape{-1, -1}; TensorShape out = TensorShape{-1, -1};
@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp
if (!outputs.empty()) return Status::OK(); if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
} }
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) { Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
outputs[0] = inputs[0]; outputs[0] = inputs[0];
return Status::OK(); return Status::OK();

View File

@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp {
~DistortBoundingBoxCropOp() override = default; ~DistortBoundingBoxCropOp() override = default;
void Print(std::ostream& out) const override { void Print(std::ostream &out) const override {
out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
} }
Status Compute(const std::vector<std::shared_ptr<Tensor>>& input, Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>>* output) override; std::vector<std::shared_ptr<Tensor>> *output) override;
uint32_t NumInput() override { return 5; } uint32_t NumInput() override { return 5; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private: private:
int32_t max_attempts_; int32_t max_attempts_;

View File

@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
rnd_.seed(GetSeed()); rnd_.seed(GetSeed());
} }
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) { Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output); IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");
@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std:
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
} }
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) { Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear(); outputs.clear();
TensorShape out = TensorShape{target_height_, target_width_}; TensorShape out = TensorShape{target_height_, target_width_};
@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs
if (!outputs.empty()) return Status::OK(); if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
} }
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
double scale, aspect; double scale, aspect;
*crop_width = w_in; *crop_width = w_in;
*crop_height = h_in; *crop_height = h_in;

View File

@ -22,7 +22,7 @@
namespace mindspore { namespace mindspore {
constexpr char PARALLEL_STRATEGY[] = "strategy"; constexpr char PARALLEL_STRATEGY[] = "strategy";
void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);
} // namespace mindspore } // namespace mindspore

View File

@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
// get MindSpore Intermediate Representation Path // get MindSpore Intermediate Representation Path
std::string GetMsIrPath(void) { std::string GetMsIrPath(void) {
std::string path; std::string path;
const char* path_ptr = getenv("MS_IR_PATH"); const char *path_ptr = getenv("MS_IR_PATH");
if (path_ptr != nullptr) { if (path_ptr != nullptr) {
path = path_ptr; path = path_ptr;
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
@ -62,13 +62,13 @@ std::string GetMsIrPath(void) {
return path; return path;
} }
std::string dump_obj(const py::object& obj, const std::string& path) { std::string dump_obj(const py::object &obj, const std::string &path) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
return py::str(name); return py::str(name);
} }
py::object load_obj(const std::string& path) { py::object load_obj(const std::string &path) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
return obj; return obj;
@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) {
// ============================================= MindSpore IR Exporter ============================================= // ============================================= MindSpore IR Exporter =============================================
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
TypePtr type = dyn_cast<Type>(nd->Type()); TypePtr type = dyn_cast<Type>(nd->Type());
std::ostringstream oss; std::ostringstream oss;
@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
return oss.str(); return oss.str();
} }
std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
std::string pkl_path = GetMsIrPath(); std::string pkl_path = GetMsIrPath();
// if not specified env 'MS_IR_PATH', do not create any files // if not specified env 'MS_IR_PATH', do not create any files
if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca
return file_prefix + file_name; return file_prefix + file_name;
} }
int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp) {
if (func_graph == nullptr || param == nullptr) { if (func_graph == nullptr || param == nullptr) {
return -1; return -1;
} }
@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
// try to find index of parameter for SymbolicKeyInstance from all exported graphs // try to find index of parameter for SymbolicKeyInstance from all exported graphs
// NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { int AnfExporter::GetParamIndexFromExported(const AnfNodePtr &param) {
if (param == nullptr) { if (param == nullptr) {
return -1; return -1;
} }
int ret = -1; int ret = -1;
for (const auto& item : exported) { for (const auto &item : exported) {
auto pram_iter = item.second.find(param); auto pram_iter = item.second.find(param);
if (pram_iter != item.second.end()) { if (pram_iter != item.second.end()) {
return pram_iter->second; return pram_iter->second;
@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
return ret; return ret;
} }
std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
return GetValueText(fg, node->value()); return GetValueText(fg, node->value());
} }
std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
auto py_funcs = mt_func_graph->GetPyFunctions(); auto py_funcs = mt_func_graph->GetPyFunctions();
if (py_funcs.empty()) { if (py_funcs.empty()) {
return ""; return "";
@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
oss << "{"; oss << "{";
bool is_first = true; bool is_first = true;
for (const auto& py_func : py_funcs) { for (const auto &py_func : py_funcs) {
if (is_first) { if (is_first) {
is_first = false; is_first = false;
} else { } else {
@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
* GradOperation * GradOperation
* TupleAdd * TupleAdd
*/ */
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
if (meta_func_graph == nullptr) { if (meta_func_graph == nullptr) {
return ""; return "";
} }
@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
std::ostringstream oss; std::ostringstream oss;
if (prim == nullptr) { if (prim == nullptr) {
return oss.str(); return oss.str();
@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
if (prim->isa<prim::DoSignaturePrimitive>()) { if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
auto& func = do_signature->function(); auto &func = do_signature->function();
if (func->isa<Primitive>()) { if (func->isa<Primitive>()) {
auto sig_prim = dyn_cast<Primitive>(func); auto sig_prim = dyn_cast<Primitive>(func);
oss << sig_prim->GetAttrsText(); oss << sig_prim->GetAttrsText();
@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
std::ostringstream oss; std::ostringstream oss;
if (ns == nullptr) { if (ns == nullptr) {
return oss.str(); return oss.str();
@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
const SymbolicKeyInstancePtr& sym_inst) { const SymbolicKeyInstancePtr &sym_inst) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(sym_inst); MS_EXCEPTION_IF_NULL(sym_inst);
AnfNodePtr sym_node = sym_inst->node(); AnfNodePtr sym_node = sym_inst->node();
@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss; std::ostringstream oss;
// output ValueList, ValueTuple // output ValueList, ValueTuple
ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value); ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss; std::ostringstream oss;
ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>(); ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
oss << "{"; oss << "{";
bool first_flag = true; bool first_flag = true;
for (const auto& elem : dict->value()) { for (const auto &elem : dict->value()) {
if (first_flag) { if (first_flag) {
first_flag = false; first_flag = false;
} else { } else {
@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
std::ostringstream oss; std::ostringstream oss;
if (check_integrity_) { if (check_integrity_) {
@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr&
return oss.str(); return oss.str();
} }
std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss; std::ostringstream oss;
bool is_null_ptr = (func_graph == nullptr || value == nullptr); bool is_null_ptr = (func_graph == nullptr || value == nullptr);
if (is_null_ptr) { if (is_null_ptr) {
@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu
} }
// this function is used to output node in CNode's inputs // this function is used to output node in CNode's inputs
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int>& apply_map) { const std::map<AnfNodePtr, int> &apply_map) {
std::ostringstream oss; std::ostringstream oss;
if (func_graph == nullptr || node == nullptr) { if (func_graph == nullptr || node == nullptr) {
return oss.str(); return oss.str();
@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
return oss.str(); return oss.str();
} }
void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) { OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
bool first_flag = true; bool first_flag = true;
for (const AnfNodePtr& param : parameters) { for (const AnfNodePtr &param : parameters) {
if (first_flag) { if (first_flag) {
first_flag = false; first_flag = false;
ofs << " "; ofs << " ";
@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode
} }
} }
void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) { void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return; return;
} }
// output type of each input argument // output type of each input argument
auto& inputs = node->inputs(); auto &inputs = node->inputs();
if (inputs.size() > 1) { if (inputs.size() > 1) {
ofs << " #("; ofs << " #(";
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod
ofs << " #scope: " << node->scope()->name(); ofs << " #scope: " << node->scope()->name();
} }
void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
const FuncGraphPtr& func_graph) { const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
int idx = 1; int idx = 1;
std::map<AnfNodePtr, int> apply_map; std::map<AnfNodePtr, int> apply_map;
for (const AnfNodePtr& node : nodes) { for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
} }
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
auto& inputs = cnode->inputs(); auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
// non-return node // non-return node
if (node != func_graph->get_return()) { if (node != func_graph->get_return()) {
@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
} }
} }
void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun
ofs << "}\n"; ofs << "}\n";
} }
void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt
ofs.close(); ofs.close();
} }
void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) { void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
if (graphs.empty()) { if (graphs.empty()) {
return; return;
} }
@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
param_index = 1; param_index = 1;
for (const auto& tagged_graph : graphs) { for (const auto &tagged_graph : graphs) {
tagged_cnodes_ = tagged_graph.second; tagged_cnodes_ = tagged_graph.second;
ExportOneFuncGraph(ofs, tagged_graph.first); ExportOneFuncGraph(ofs, tagged_graph.first);
tagged_cnodes_.clear(); tagged_cnodes_.clear();
@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
} }
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap
ChangeFileMode(filename, S_IRUSR); ChangeFileMode(filename, S_IRUSR);
} }
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
AnfExporter exporter("", false); AnfExporter exporter("", false);
ChangeFileMode(filename, S_IRWXU); ChangeFileMode(filename, S_IRWXU);
exporter.ExportFuncGraph(filename, graphs); exporter.ExportFuncGraph(filename, graphs);
@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph
ChangeFileMode(filename, S_IRUSR); ChangeFileMode(filename, S_IRUSR);
} }
#else #else
void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;
@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script."; << "please recompile source to enable it. See help of building script.";
} }
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;
@ -732,7 +732,7 @@ enum Token : int {
TOK_ERROR // file read error TOK_ERROR // file read error
}; };
std::map<Token, const char*> token_text = { std::map<Token, const char *> token_text = {
{TOK_INVALID, "invalid"}, // invalid token {TOK_INVALID, "invalid"}, // invalid token
{TOK_LPARENTHESIS, "("}, // ( left parenthesis {TOK_LPARENTHESIS, "("}, // ( left parenthesis
{TOK_RPARENTHESIS, ")"}, // ) right parenthesis {TOK_RPARENTHESIS, ")"}, // ) right parenthesis
@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = {
class Lexer { class Lexer {
public: public:
// filename is checked in ImportIR; // filename is checked in ImportIR;
explicit Lexer(const char* filename) : fin(filename) {} explicit Lexer(const char *filename) : fin(filename) {}
~Lexer() { ~Lexer() {
try { try {
if (fin.is_open()) { if (fin.is_open()) {
fin.close(); fin.close();
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file"; MS_LOG(ERROR) << "Exception when closing file";
} catch (...) { } catch (...) {
std::string exName(abi::__cxa_current_exception_type()->name()); std::string exName(abi::__cxa_current_exception_type()->name());
@ -776,7 +776,7 @@ class Lexer {
} }
} }
bool IsSingleCharToken(char ch, Token* token_ptr) { bool IsSingleCharToken(char ch, Token *token_ptr) {
// clang-format off // clang-format off
std::unordered_map<char, Token> char_to_token = { std::unordered_map<char, Token> char_to_token = {
{'(', TOK_LPARENTHESIS}, {'(', TOK_LPARENTHESIS},
@ -806,7 +806,7 @@ class Lexer {
Token GetNextToken() { Token GetNextToken() {
#ifdef DEBUG #ifdef DEBUG
Token token = GetNextTokenInner(); Token token = GetNextTokenInner();
const char* str = token_text[token]; const char *str = token_text[token];
std::string text = (str == nullptr ? GetTokenText() : str); std::string text = (str == nullptr ? GetTokenText() : str);
MS_LOG(DEBUG) << "------Parse token] " << text; MS_LOG(DEBUG) << "------Parse token] " << text;
return token; return token;
@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE;
class IrParser { class IrParser {
public: public:
explicit IrParser(const char* filename) : lexer_(filename) {} explicit IrParser(const char *filename) : lexer_(filename) {}
~IrParser() {} ~IrParser() {}
py::object LoadObject(const std::string& file_name) const { py::object LoadObject(const std::string &file_name) const {
std::string pkl_path = GetMsIrPath(); std::string pkl_path = GetMsIrPath();
py::object default_obj = load_obj(pkl_path + "/" + file_name); py::object default_obj = load_obj(pkl_path + "/" + file_name);
return default_obj; return default_obj;
@ -1087,7 +1087,7 @@ class IrParser {
MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
} }
Token ParseParent(FuncGraphPtr* const parent_ptr) { Token ParseParent(FuncGraphPtr *const parent_ptr) {
if (lexer_.GetNextToken() != TOK_IDENTIFIER) { if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1168,7 +1168,7 @@ class IrParser {
return func_graph; return func_graph;
} }
FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
Token tok = lexer_.SkipWhiteToken(); Token tok = lexer_.SkipWhiteToken();
while (tok == TOK_VARIABLE) { while (tok == TOK_VARIABLE) {
if (ParseStatement(func_graph) == nullptr) { if (ParseStatement(func_graph) == nullptr) {
@ -1264,56 +1264,56 @@ class IrParser {
return func_graph; return func_graph;
} }
void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = dtype; *ptr = dtype;
} }
void SetTupleType(TypePtr* ptr) { void SetTupleType(TypePtr *ptr) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<Tuple>(); *ptr = std::make_shared<Tuple>();
} }
void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<Tuple>(elems); *ptr = std::make_shared<Tuple>(elems);
} }
void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) { void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<TensorType>(elem_type); *ptr = std::make_shared<TensorType>(elem_type);
} }
void SetListType(TypePtr* ptr) { void SetListType(TypePtr *ptr) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<List>(); *ptr = std::make_shared<List>();
} }
void SetListType(TypePtr* ptr, const TypePtrList& elems) { void SetListType(TypePtr *ptr, const TypePtrList &elems) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<List>(elems); *ptr = std::make_shared<List>(elems);
} }
void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<JTagged>(elem); *ptr = std::make_shared<JTagged>(elem);
} }
void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
@ -1321,45 +1321,45 @@ class IrParser {
} }
// void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<abstract::AbstractNone>(); *ptr = std::make_shared<abstract::AbstractNone>();
} }
void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}
void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
// if one of elems is nullptr, just return // if one of elems is nullptr, just return
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
return; return;
} }
*ptr = std::make_shared<abstract::AbstractTuple>(elems); *ptr = std::make_shared<abstract::AbstractTuple>(elems);
} }
void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) { void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
*ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape); *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
} }
void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
return; return;
} }
*ptr = std::make_shared<abstract::AbstractList>(elems); *ptr = std::make_shared<abstract::AbstractList>(elems);
} }
void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
if (ptr == nullptr) { if (ptr == nullptr) {
return; return;
} }
@ -1367,7 +1367,7 @@ class IrParser {
} }
template <typename T> template <typename T>
Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
if (tok != TOK_LBRACKET) { if (tok != TOK_LBRACKET) {
MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
return tok; return tok;
@ -1415,7 +1415,7 @@ class IrParser {
} }
template <typename T> template <typename T>
Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
if (tok != TOK_LPARENTHESIS) { if (tok != TOK_LPARENTHESIS) {
if (ptr != nullptr) { if (ptr != nullptr) {
SetBasicType(ptr, std::make_shared<TensorType>()); SetBasicType(ptr, std::make_shared<TensorType>());
@ -1454,7 +1454,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
bool IsNumberType(const std::string& type, TypeId* typeid_ptr) { bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
// clang-format off // clang-format off
static std::unordered_map<std::string, TypeId> basic_types = { static std::unordered_map<std::string, TypeId> basic_types = {
{"Bool", kNumberTypeBool}, {"Bool", kNumberTypeBool},
@ -1486,7 +1486,7 @@ class IrParser {
} }
template <typename T> template <typename T>
void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
TypePtr dtype = nullptr; TypePtr dtype = nullptr;
std::unordered_map<int, TypePtr> type_map = { std::unordered_map<int, TypePtr> type_map = {
@ -1519,7 +1519,7 @@ class IrParser {
} }
template <typename T> template <typename T>
Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
if (type == "NoneType") { if (type == "NoneType") {
SetBasicType(ptr, std::make_shared<TypeNone>()); SetBasicType(ptr, std::make_shared<TypeNone>());
return lexer_.GetNextToken(); return lexer_.GetNextToken();
@ -1541,7 +1541,7 @@ class IrParser {
} }
template <typename T> template <typename T>
Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
if (tok != TOK_IDENTIFIER) { if (tok != TOK_IDENTIFIER) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1588,11 +1588,11 @@ class IrParser {
} }
} }
Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
} }
Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
Token tok = ParseAttribute(func_graph, prim); Token tok = ParseAttribute(func_graph, prim);
while (tok == TOK_COMMA) { while (tok == TOK_COMMA) {
tok = ParseAttribute(func_graph, prim); tok = ParseAttribute(func_graph, prim);
@ -1603,7 +1603,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
Token tok = lexer_.GetNextToken(); Token tok = lexer_.GetNextToken();
if (tok != TOK_IDENTIFIER) { if (tok != TOK_IDENTIFIER) {
return TOK_ERROR; return TOK_ERROR;
@ -1670,7 +1670,7 @@ class IrParser {
return tok == TOK_RPARENTHESIS ? func_graph : nullptr; return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
} }
FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
Token tok = ParseArgument(func_graph, inputs_ptr); Token tok = ParseArgument(func_graph, inputs_ptr);
while (tok == TOK_COMMA) { while (tok == TOK_COMMA) {
tok = ParseArgument(func_graph, inputs_ptr); tok = ParseArgument(func_graph, inputs_ptr);
@ -1681,9 +1681,9 @@ class IrParser {
return func_graph; return func_graph;
} }
AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string &param_name) {
while (func_graph != nullptr) { while (func_graph != nullptr) {
for (auto& ptr : func_graph->parameters()) { for (auto &ptr : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(ptr); MS_EXCEPTION_IF_NULL(ptr);
ParameterPtr param = ptr->cast<ParameterPtr>(); ParameterPtr param = ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
@ -1701,12 +1701,12 @@ class IrParser {
return nullptr; return nullptr;
} }
bool Match(const std::string& str, const std::string& pattern) const { bool Match(const std::string &str, const std::string &pattern) const {
return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
} }
template <typename T, typename V> template <typename T, typename V>
Token ParseScalar(ValuePtr* const val_ptr) { Token ParseScalar(ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_NUMBER) { if (lexer_.GetNextToken() != TOK_NUMBER) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1725,7 +1725,7 @@ class IrParser {
} }
template <typename VT, typename V, typename T> template <typename VT, typename V, typename T>
Token ParseScalar(ValuePtr* const val_ptr, Token tok) { Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
if (tok != TOK_LPARENTHESIS) { if (tok != TOK_LPARENTHESIS) {
*val_ptr = std::make_shared<T>(); *val_ptr = std::make_shared<T>();
return tok; return tok;
@ -1735,7 +1735,7 @@ class IrParser {
} }
template <typename VT, typename V, typename T, const unsigned nbits> template <typename VT, typename V, typename T, const unsigned nbits>
Token ParseScalar(ValuePtr* const val_ptr, Token tok) { Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
if (tok != TOK_LPARENTHESIS) { if (tok != TOK_LPARENTHESIS) {
*val_ptr = std::make_shared<T>(nbits); *val_ptr = std::make_shared<T>(nbits);
return tok; return tok;
@ -1745,7 +1745,7 @@ class IrParser {
} }
template <typename T> template <typename T>
T StringToScalar(const std::string& text) { T StringToScalar(const std::string &text) {
std::stringstream ss; std::stringstream ss;
T value; T value;
ss << text; ss << text;
@ -1753,7 +1753,7 @@ class IrParser {
return value; return value;
} }
Token ParseTensor(ValuePtr* const val_ptr) { Token ParseTensor(ValuePtr *const val_ptr) {
// parse type // parse type
TypeId type; TypeId type;
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
@ -1803,7 +1803,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
if (tok != TOK_LBRACE) { if (tok != TOK_LBRACE) {
return tok; return tok;
} }
@ -1830,7 +1830,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
if (tok != TOK_LPARENTHESIS) { if (tok != TOK_LPARENTHESIS) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1855,7 +1855,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
if (tok != TOK_LBRACE) { if (tok != TOK_LBRACE) {
return tok; return tok;
} }
@ -1868,7 +1868,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseBoolValue(const std::string& key, bool* val_ptr) { Token ParseBoolValue(const std::string &key, bool *val_ptr) {
if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1892,7 +1892,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_LBRACE) { if (lexer_.GetNextToken() != TOK_LBRACE) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1920,7 +1920,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1951,7 +1951,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_AT_FILE) { if (lexer_.GetNextToken() != TOK_AT_FILE) {
return TOK_ERROR; return TOK_ERROR;
} }
@ -1984,7 +1984,7 @@ class IrParser {
return next; return next;
} }
Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) {
if (Match(id, "MultitypeFuncGraph::")) { if (Match(id, "MultitypeFuncGraph::")) {
std::string name = id.substr(strlen("MultitypeFuncGraph::")); std::string name = id.substr(strlen("MultitypeFuncGraph::"));
auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name); auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
@ -2024,8 +2024,8 @@ class IrParser {
} }
} }
Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr,
AnfNodePtr* const node_ptr = nullptr) { AnfNodePtr *const node_ptr = nullptr) {
if (id == "None") { if (id == "None") {
*val_ptr = std::make_shared<None>(); *val_ptr = std::make_shared<None>();
return lexer_.GetNextToken(); return lexer_.GetNextToken();
@ -2075,9 +2075,9 @@ class IrParser {
} }
} }
Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes, const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
if (node_is_valid && node_ptr != nullptr) { if (node_is_valid && node_ptr != nullptr) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -2097,8 +2097,8 @@ class IrParser {
} }
} }
Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
AnfNodePtr* node_ptr = nullptr) { AnfNodePtr *node_ptr = nullptr) {
Token left_tok = tok; Token left_tok = tok;
std::vector<ValuePtr> elems; std::vector<ValuePtr> elems;
@ -2138,7 +2138,7 @@ class IrParser {
return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
} }
Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
// tuple or list // tuple or list
if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
@ -2152,7 +2152,7 @@ class IrParser {
return TOK_ERROR; return TOK_ERROR;
} }
Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
Token tok = TOK_INVALID) { Token tok = TOK_INVALID) {
if (tok == TOK_INVALID) { if (tok == TOK_INVALID) {
tok = lexer_.GetNextToken(); tok = lexer_.GetNextToken();
@ -2193,7 +2193,7 @@ class IrParser {
return lexer_.GetNextToken(); return lexer_.GetNextToken();
} }
Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
Token tok = lexer_.GetNextToken(); Token tok = lexer_.GetNextToken();
if (tok == TOK_RPARENTHESIS) { if (tok == TOK_RPARENTHESIS) {
return tok; return tok;
@ -2208,7 +2208,7 @@ class IrParser {
return tok; return tok;
} }
const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; } const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }
private: private:
Lexer lexer_; Lexer lexer_;
@ -2226,14 +2226,14 @@ class IrParser {
std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
}; };
std::vector<FuncGraphPtr> ImportIR(const std::string& filename) { std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
IrParser parser(filename.c_str()); IrParser parser(filename.c_str());
parser.ParseFile(); parser.ParseFile();
return parser.GetFuncGraphs(); return parser.GetFuncGraphs();
} }
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func graph is nullptr"; MS_LOG(ERROR) << "Func graph is nullptr";
return; return;
@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
return; return;
} }
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
char* real_path_ret = nullptr; char *real_path_ret = nullptr;
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
#else #else
@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
ChangeFileMode(file_path, S_IRUSR); ChangeFileMode(file_path, S_IRUSR);
} }
#else #else
void DumpIRProto(const FuncGraphPtr&, const std::string&) { void DumpIRProto(const FuncGraphPtr &, const std::string &) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;

View File

@ -39,7 +39,7 @@
namespace mindspore { namespace mindspore {
struct ParamPtrEqual { struct ParamPtrEqual {
bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const {
const ParameterPtr param1 = dyn_cast<Parameter>(t1); const ParameterPtr param1 = dyn_cast<Parameter>(t1);
const ParameterPtr param2 = dyn_cast<Parameter>(t2); const ParameterPtr param2 = dyn_cast<Parameter>(t2);
@ -52,7 +52,7 @@ struct ParamPtrEqual {
}; };
struct ParamPtrHasher { struct ParamPtrHasher {
std::size_t operator()(AnfNodePtr const& param) const { std::size_t operator()(AnfNodePtr const &param) const {
const ParameterPtr parameter = dyn_cast<Parameter>(param); const ParameterPtr parameter = dyn_cast<Parameter>(param);
if (parameter == nullptr) { if (parameter == nullptr) {
return 0; return 0;
@ -64,39 +64,39 @@ struct ParamPtrHasher {
class AnfExporter { class AnfExporter {
public: public:
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear(); func_graph_set.clear();
exported.clear(); exported.clear();
} }
virtual ~AnfExporter() {} virtual ~AnfExporter() {}
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs);
protected: protected:
virtual std::string GetNodeType(const AnfNodePtr& nd); virtual std::string GetNodeType(const AnfNodePtr &nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param); int GetParamIndexFromExported(const AnfNodePtr &param);
std::string DumpObject(const py::object& obj, const std::string& category) const; std::string DumpObject(const py::object &obj, const std::string &category) const;
std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst);
std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetPrimitiveText(const PrimitivePtr& prim); std::string GetPrimitiveText(const PrimitivePtr &prim);
std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetNameSpaceText(const parse::NameSpacePtr& ns); std::string GetNameSpaceText(const parse::NameSpacePtr &ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int>& apply_map); const std::map<AnfNodePtr, int> &apply_map);
void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map); OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map);
void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph); void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
int param_index; int param_index;
OrderedSet<FuncGraphPtr> func_graph_set{}; OrderedSet<FuncGraphPtr> func_graph_set{};
@ -108,16 +108,16 @@ class AnfExporter {
abstract::AnfNodeConfigPtr node_cfg_ = nullptr; abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
}; };
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph);
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs); void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs);
std::vector<FuncGraphPtr> ImportIR(const std::string& filename); std::vector<FuncGraphPtr> ImportIR(const std::string &filename);
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_

View File

@ -34,7 +34,7 @@ namespace draw {
namespace { namespace {
// Only for ValueNode // Only for ValueNode
std::string ValueType(const ValueNodePtr& node) { std::string ValueType(const ValueNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return ""; return "";
} }
@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) {
return v->type_name(); return v->type_name();
} }
std::string ReplaceSpecialChar(const std::string& str) { std::string ReplaceSpecialChar(const std::string &str) {
std::ostringstream oss; std::ostringstream oss;
for (size_t i = 0; i < str.size(); i++) { for (size_t i = 0; i < str.size(); i++) {
if (str[i] == '<') { if (str[i] == '<') {
@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) {
} // namespace } // namespace
// API of debug utils // API of debug utils
void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs, void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
bool is_user) { bool is_user) {
if (sub_graphs == nullptr) { if (sub_graphs == nullptr) {
return; return;
} }
for (auto& nd : nodes) { for (auto &nd : nodes) {
MS_EXCEPTION_IF_NULL(nd); MS_EXCEPTION_IF_NULL(nd);
auto sub_graph = nd->func_graph(); auto sub_graph = nd->func_graph();
if (sub_graph != nullptr) { if (sub_graph != nullptr) {
@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st
} }
} }
void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) { OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
if (sub_graphs == nullptr) { if (sub_graphs == nullptr) {
return; return;
} }
int dup_idx = 0; int dup_idx = 0;
for (auto& nd : nodes) { for (auto &nd : nodes) {
for (auto& t : SuccIncoming(nd)) { for (auto &t : SuccIncoming(nd)) {
MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(t);
MS_EXCEPTION_IF_NULL(nd); MS_EXCEPTION_IF_NULL(nd);
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
} }
} }
void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) { void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
if (digraph == nullptr) { if (digraph == nullptr) {
return; return;
} }
@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
} }
// Draw edge // Draw edge
for (auto& nd : nodes) { for (auto &nd : nodes) {
auto succs = SuccIncoming(nd); auto succs = SuccIncoming(nd);
auto num = succs.size(); auto num = succs.size();
for (size_t i = 0; i < num; i++) { for (size_t i = 0; i < num; i++) {
auto& t = succs.at(i); auto &t = succs.at(i);
MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(t);
if (t->isa<ValueNode>() || t->isa<Parameter>()) { if (t->isa<ValueNode>() || t->isa<Parameter>()) {
if ((!is_user) || (i != 0)) { if ((!is_user) || (i != 0)) {
@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
} }
} }
void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) { void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
DrawValueNodes(nodes, &sub_graphs); DrawValueNodes(nodes, &sub_graphs);
// Draw subgraph // Draw subgraph
for (const auto& gsub : sub_graphs) { for (const auto &gsub : sub_graphs) {
digraph->SubGraph(gsub.first, gsub.second); digraph->SubGraph(gsub.first, gsub.second);
} }
@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
} }
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
const std::string dot_suffix = ".dot"; const std::string dot_suffix = ".dot";
std::string filename_with_suffix = std::string filename_with_suffix =
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
DrawByOpt(filename_with_suffix, func_graph, false); DrawByOpt(filename_with_suffix, func_graph, false);
} }
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
DrawByOpt(filename, func_graph, true); DrawByOpt(filename, func_graph, true);
} }
#else #else
void Draw(const std::string&, const FuncGraphPtr&) { void Draw(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;
@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script."; << "please recompile source to enable it. See help of building script.";
} }
void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;
@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) {
return "plaintext"; return "plaintext";
} }
std::string Graphviz::Color(const AnfNodePtr& node) { std::string Graphviz::Color(const AnfNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return ""; return "";
} }
@ -259,7 +259,7 @@ void BaseDigraph::Start() {
buffer_ << "compound=true" << std::endl; buffer_ << "compound=true" << std::endl;
} }
void BaseDigraph::Head(const AnfNodePtr& node, int id) { void BaseDigraph::Head(const AnfNodePtr &node, int id) {
if (node == nullptr) { if (node == nullptr) {
return; return;
} }
@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) {
} }
} }
void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
if (node == nullptr) { if (node == nullptr) {
return; return;
} }
@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
buffer_ << ":" << idx; buffer_ << ":" << idx;
} }
void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
@ -304,12 +304,12 @@ void BaseDigraph::End() {
} }
} }
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << "parameters_" << key << "[shape=plaintext "; buffer_ << "parameters_" << key << "[shape=plaintext ";
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>"; buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
buffer_ << "<tr><td>parameters</td></tr>"; buffer_ << "<tr><td>parameters</td></tr>";
int count = 0; int count = 0;
for (auto& parameter : key->parameters()) { for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>"; buffer_ << "<tr><td>";
buffer_ << parameter->ToString(); buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param(); auto py_p = dyn_cast<Parameter>(parameter)->default_param();
@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
buffer_ << "</table>>,];"; buffer_ << "</table>>,];";
} }
void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) { void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
if (key == nullptr || gsub == nullptr) { if (key == nullptr || gsub == nullptr) {
return; return;
} }
@ -361,12 +361,12 @@ Digraph::~Digraph() {
if (fout_.is_open()) { if (fout_.is_open()) {
fout_.close(); fout_.close();
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file " << filename_; MS_LOG(ERROR) << "Exception when closing file " << filename_;
} }
} }
static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
size_t start_pos = 0; size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) { while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
(void)str.replace(start_pos, from.length(), to); (void)str.replace(start_pos, from.length(), to);
@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st
return str; return str;
} }
static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph_obj); MS_EXCEPTION_IF_NULL(graph_obj);
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node) graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
<< "'>"; << "'>";
@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</td></tr>"; graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>"; graph_obj->buffer() << "<tr><td align='left'>";
int i = 0; int i = 0;
for (const auto& attr : attrs) { for (const auto &attr : attrs) {
if (i != 0) { if (i != 0) {
graph_obj->buffer() << "<br/>"; graph_obj->buffer() << "<br/>";
} }
@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</table>>,"; graph_obj->buffer() << "</table>>,";
} }
static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) { if (graph_obj == nullptr || node == nullptr) {
return; return;
} }
@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
} }
} }
static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr || node->size() == 0) { if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
return; return;
} }
@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
} }
graph_obj->buffer() << ">"; graph_obj->buffer() << ">";
int i = 0; int i = 0;
for (auto& attr : attrs) { for (auto &attr : attrs) {
if (i != 0) { if (i != 0) {
graph_obj->buffer() << "<br/>"; graph_obj->buffer() << "<br/>";
} }
@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() {
if (fout_.is_open()) { if (fout_.is_open()) {
fout_.close(); fout_.close();
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "exception when closing file " << filename_; MS_LOG(ERROR) << "exception when closing file " << filename_;
} }
} }

View File

@ -31,9 +31,9 @@ namespace parse = mindspore::parse;
class Graphviz { class Graphviz {
public: public:
Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {}
explicit Graphviz(const std::string& name) : name_(name) {} explicit Graphviz(const std::string &name) : name_(name) {}
virtual ~Graphviz() {} virtual ~Graphviz() {}
@ -41,8 +41,8 @@ class Graphviz {
virtual void End() {} virtual void End() {}
virtual std::string Shape(AnfNodePtr node); virtual std::string Shape(AnfNodePtr node);
std::string Color(const AnfNodePtr& node); std::string Color(const AnfNodePtr &node);
std::ostringstream& buffer() { return buffer_; } std::ostringstream &buffer() { return buffer_; }
std::ostringstream buffer_; std::ostringstream buffer_;
protected: protected:
@ -53,8 +53,8 @@ class Graphviz {
class BaseDigraph : public Graphviz { class BaseDigraph : public Graphviz {
public: public:
BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string& name) : Graphviz(name) {} explicit BaseDigraph(const std::string &name) : Graphviz(name) {}
~BaseDigraph() override = default; ~BaseDigraph() override = default;
virtual void Node(AnfNodePtr node, int id = 0) = 0; virtual void Node(AnfNodePtr node, int id = 0) = 0;
@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz {
void Start() override; void Start() override;
void End() override; void End() override;
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
void FuncGraphParameters(const FuncGraphPtr& key); void FuncGraphParameters(const FuncGraphPtr &key);
void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub); void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub);
const std::string& name() const { return name_; } const std::string &name() const { return name_; }
protected: protected:
void Head(const AnfNodePtr& node, int id = 0); void Head(const AnfNodePtr &node, int id = 0);
void Tail(const AnfNodePtr& node, int idx, int id = 0); void Tail(const AnfNodePtr &node, int idx, int id = 0);
void Tail(const FuncGraphPtr& func_graph); void Tail(const FuncGraphPtr &func_graph);
}; };
class Digraph : public BaseDigraph { class Digraph : public BaseDigraph {
public: public:
Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string& name) : BaseDigraph(name) {} explicit Digraph(const std::string &name) : BaseDigraph(name) {}
~Digraph() override; ~Digraph() override;
void Node(AnfNodePtr node, int id = 0) override; void Node(AnfNodePtr node, int id = 0) override;
@ -86,8 +86,8 @@ class Digraph : public BaseDigraph {
class ModelDigraph : public BaseDigraph { class ModelDigraph : public BaseDigraph {
public: public:
ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {}
~ModelDigraph() override; ~ModelDigraph() override;
std::string Shape(AnfNodePtr node) override; std::string Shape(AnfNodePtr node) override;
@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph {
}; };
// API to draw // API to draw
void Draw(const std::string& filename, const FuncGraphPtr& func_graph); void Draw(const std::string &filename, const FuncGraphPtr &func_graph);
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
} // namespace draw } // namespace draw
} // namespace mindspore } // namespace mindspore

View File

@ -33,38 +33,38 @@ class ProtoExporter {
ProtoExporter() {} ProtoExporter() {}
~ProtoExporter() {} ~ProtoExporter() {}
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
private: private:
void InitModelInfo(); void InitModelInfo();
void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto);
std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t>& apply_map, const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr); std::map<AnfNodePtr, size_t> *const_map_ptr);
void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
std::map<AnfNodePtr, size_t>* const_map_ptr); std::map<AnfNodePtr, size_t> *const_map_ptr);
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr, void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto); std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr, const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
irpb::GraphProto* graph_proto); irpb::GraphProto *graph_proto);
void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto); void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);
static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
irpb::ModelProto model_; irpb::ModelProto model_;
}; };
static irpb::DataType GetNumberDataType(const TypePtr& type) { static irpb::DataType GetNumberDataType(const TypePtr &type) {
switch (type->type_id()) { switch (type->type_id()) {
case kNumberTypeBool: case kNumberTypeBool:
return irpb::DT_BOOL; return irpb::DT_BOOL;
@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) {
} }
} }
void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
if (type_proto == nullptr) { if (type_proto == nullptr) {
return; return;
} }
@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
type_proto->set_data_type(irpb::DT_TENSOR); type_proto->set_data_type(irpb::DT_TENSOR);
if (shape != nullptr && shape->isa<abstract::Shape>()) { if (shape != nullptr && shape->isa<abstract::Shape>()) {
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
for (const auto& elem : shape_info->shape()) { for (const auto &elem : shape_info->shape()) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
} }
} }
} else if (type->isa<Tuple>()) { } else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type); TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE); type_proto->set_data_type(irpb::DT_TUPLE);
for (const auto& elem_type : tuple_type->elements()) { for (const auto &elem_type : tuple_type->elements()) {
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
} }
} else if (type->isa<TypeType>()) { } else if (type->isa<TypeType>()) {
@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
} else if (type->isa<List>()) { } else if (type->isa<List>()) {
ListPtr list_type = dyn_cast<List>(type); ListPtr list_type = dyn_cast<List>(type);
type_proto->set_data_type(irpb::DT_LIST); type_proto->set_data_type(irpb::DT_LIST);
for (const auto& elem_type : list_type->elements()) { for (const auto &elem_type : list_type->elements()) {
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
} }
} else if (type->isa<TypeAnything>()) { } else if (type->isa<TypeAnything>()) {
@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
} }
} }
void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
if (node == nullptr || type_proto == nullptr) { if (node == nullptr || type_proto == nullptr) {
return; return;
} }
SetNodeOutputType(node->Type(), node->Shape(), type_proto); SetNodeOutputType(node->Type(), node->Shape(), type_proto);
} }
void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) { if (val == nullptr || value_proto == nullptr) {
return; return;
} }
if (val->isa<StringImm>()) { if (val->isa<StringImm>()) {
const StringImmPtr& value = dyn_cast<StringImm>(val); const StringImmPtr &value = dyn_cast<StringImm>(val);
value_proto->set_dtype(irpb::DT_STRING); value_proto->set_dtype(irpb::DT_STRING);
value_proto->set_str_val(value->value()); value_proto->set_str_val(value->value());
} else if (val->isa<Scalar>()) { } else if (val->isa<Scalar>()) {
@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
} else if (val->isa<tensor::Tensor>()) { } else if (val->isa<tensor::Tensor>()) {
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val); tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
value_proto->set_dtype(irpb::DT_TENSOR); value_proto->set_dtype(irpb::DT_TENSOR);
irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
for (auto& elem : tensor_ptr->shape()) { for (auto &elem : tensor_ptr->shape()) {
tensor_proto->add_dims(elem); tensor_proto->add_dims(elem);
} }
} else if (val->isa<TensorType>()) { } else if (val->isa<TensorType>()) {
value_proto->set_dtype(irpb::DT_TYPE); value_proto->set_dtype(irpb::DT_TYPE);
irpb::TypeProto* type_proto = value_proto->mutable_type_val(); irpb::TypeProto *type_proto = value_proto->mutable_type_val();
type_proto->set_data_type(irpb::DT_TENSOR); type_proto->set_data_type(irpb::DT_TENSOR);
TypePtr elem_type = dyn_cast<TensorType>(val)->element(); TypePtr elem_type = dyn_cast<TensorType>(val)->element();
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
} }
} }
void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) { if (val == nullptr || value_proto == nullptr) {
return; return;
} }
if (val->isa<BoolImm>()) { if (val->isa<BoolImm>()) {
const BoolImmPtr& value = dyn_cast<BoolImm>(val); const BoolImmPtr &value = dyn_cast<BoolImm>(val);
value_proto->set_dtype(irpb::DT_BOOL); value_proto->set_dtype(irpb::DT_BOOL);
value_proto->set_bool_val(value->value()); value_proto->set_bool_val(value->value());
} else if (val->isa<Int8Imm>()) { } else if (val->isa<Int8Imm>()) {
const Int8ImmPtr& value = dyn_cast<Int8Imm>(val); const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
value_proto->set_dtype(irpb::DT_INT8); value_proto->set_dtype(irpb::DT_INT8);
value_proto->set_int_val(value->value()); value_proto->set_int_val(value->value());
} else if (val->isa<Int16Imm>()) { } else if (val->isa<Int16Imm>()) {
const Int16ImmPtr& value = dyn_cast<Int16Imm>(val); const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
value_proto->set_dtype(irpb::DT_INT16); value_proto->set_dtype(irpb::DT_INT16);
value_proto->set_int_val(value->value()); value_proto->set_int_val(value->value());
} else if (val->isa<Int32Imm>()) { } else if (val->isa<Int32Imm>()) {
const Int32ImmPtr& value = dyn_cast<Int32Imm>(val); const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
value_proto->set_dtype(irpb::DT_INT32); value_proto->set_dtype(irpb::DT_INT32);
value_proto->set_int_val(value->value()); value_proto->set_int_val(value->value());
} else if (val->isa<Int64Imm>()) { } else if (val->isa<Int64Imm>()) {
const Int64ImmPtr& value = dyn_cast<Int64Imm>(val); const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
value_proto->set_dtype(irpb::DT_INT64); value_proto->set_dtype(irpb::DT_INT64);
value_proto->set_int_val(value->value()); value_proto->set_int_val(value->value());
} else if (val->isa<UInt8Imm>()) { } else if (val->isa<UInt8Imm>()) {
const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val); const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
value_proto->set_dtype(irpb::DT_UINT8); value_proto->set_dtype(irpb::DT_UINT8);
value_proto->set_uint_val(value->value()); value_proto->set_uint_val(value->value());
} else if (val->isa<UInt16Imm>()) { } else if (val->isa<UInt16Imm>()) {
const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val); const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
value_proto->set_dtype(irpb::DT_UINT16); value_proto->set_dtype(irpb::DT_UINT16);
value_proto->set_uint_val(value->value()); value_proto->set_uint_val(value->value());
} else if (val->isa<UInt32Imm>()) { } else if (val->isa<UInt32Imm>()) {
const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val); const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
value_proto->set_dtype(irpb::DT_UINT32); value_proto->set_dtype(irpb::DT_UINT32);
value_proto->set_uint_val(value->value()); value_proto->set_uint_val(value->value());
} else if (val->isa<UInt64Imm>()) { } else if (val->isa<UInt64Imm>()) {
const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val); const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
value_proto->set_dtype(irpb::DT_UINT64); value_proto->set_dtype(irpb::DT_UINT64);
value_proto->set_uint_val(value->value()); value_proto->set_uint_val(value->value());
} else if (val->isa<FP32Imm>()) { } else if (val->isa<FP32Imm>()) {
const FP32ImmPtr& value = dyn_cast<FP32Imm>(val); const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
value_proto->set_dtype(irpb::DT_FLOAT32); value_proto->set_dtype(irpb::DT_FLOAT32);
value_proto->set_float_val(value->value()); value_proto->set_float_val(value->value());
} else if (val->isa<FP64Imm>()) { } else if (val->isa<FP64Imm>()) {
const FP64ImmPtr& value = dyn_cast<FP64Imm>(val); const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
value_proto->set_dtype(irpb::DT_FLOAT64); value_proto->set_dtype(irpb::DT_FLOAT64);
value_proto->set_double_val(value->value()); value_proto->set_double_val(value->value());
} else { } else {
@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val
} }
} }
void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) { if (val == nullptr || value_proto == nullptr) {
return; return;
} }
if (val->isa<ValueTuple>()) { if (val->isa<ValueTuple>()) {
const ValueTuplePtr& value = dyn_cast<ValueTuple>(val); const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
value_proto->set_dtype(irpb::DT_TUPLE); value_proto->set_dtype(irpb::DT_TUPLE);
for (const auto& item : value->value()) { for (const auto &item : value->value()) {
SetValueToProto(item, value_proto->add_values()); SetValueToProto(item, value_proto->add_values());
} }
} else if (val->isa<ValueList>()) { } else if (val->isa<ValueList>()) {
const ValueListPtr& value = dyn_cast<ValueList>(val); const ValueListPtr &value = dyn_cast<ValueList>(val);
value_proto->set_dtype(irpb::DT_LIST); value_proto->set_dtype(irpb::DT_LIST);
for (const auto& item : value->value()) { for (const auto &item : value->value()) {
SetValueToProto(item, value_proto->add_values()); SetValueToProto(item, value_proto->add_values());
} }
} }
} }
void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) { if (val == nullptr || value_proto == nullptr) {
return; return;
} }
value_proto->set_dtype(irpb::DT_DICT); value_proto->set_dtype(irpb::DT_DICT);
for (const auto& item : val->value()) { for (const auto &item : val->value()) {
irpb::NamedValueProto* named_val = value_proto->add_dict_val(); irpb::NamedValueProto *named_val = value_proto->add_dict_val();
named_val->set_key(item.first); named_val->set_key(item.first);
SetValueToProto(item.second, named_val->mutable_value()); SetValueToProto(item.second, named_val->mutable_value());
} }
} }
void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) {
if (node == nullptr || node_proto == nullptr) { if (node == nullptr || node_proto == nullptr) {
return; return;
} }
@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr&
MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
} }
const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node); const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
node_proto->set_op_type(prim->name()); node_proto->set_op_type(prim->name());
for (const auto& attr : prim->attrs()) { for (const auto &attr : prim->attrs()) {
irpb::AttributeProto* attr_proto = node_proto->add_attribute(); irpb::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first); attr_proto->set_name(attr.first);
SetValueToProto(attr.second, attr_proto->mutable_value()); SetValueToProto(attr.second, attr_proto->mutable_value());
} }
node_proto->set_scope(node->scope()->name()); node_proto->set_scope(node->scope()->name());
} }
std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t>& apply_map, const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr) { std::map<AnfNodePtr, size_t> *const_map_ptr) {
if (node == nullptr || const_map_ptr == nullptr) { if (node == nullptr || const_map_ptr == nullptr) {
return ""; return "";
} }
@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt
MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
} }
std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return ""; return "";
} }
InitModelInfo(); InitModelInfo();
irpb::GraphProto* graph_proto = model_.mutable_graph(); irpb::GraphProto *graph_proto = model_.mutable_graph();
ExportFuncGraph(func_graph, graph_proto); ExportFuncGraph(func_graph, graph_proto);
return model_.SerializeAsString(); return model_.SerializeAsString();
} }
void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || graph_proto == nullptr) { if (func_graph == nullptr || graph_proto == nullptr) {
return; return;
} }
@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP
ExportValueNodes(const_map, graph_proto); ExportValueNodes(const_map, graph_proto);
} }
void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || graph_proto == nullptr) { if (func_graph == nullptr || graph_proto == nullptr) {
return; return;
} }
std::vector<AnfNodePtr> parameters = func_graph->parameters(); std::vector<AnfNodePtr> parameters = func_graph->parameters();
for (auto& param : parameters) { for (auto &param : parameters) {
irpb::ParameterProto* param_proto = graph_proto->add_parameters(); irpb::ParameterProto *param_proto = graph_proto->add_parameters();
param_proto->set_name(param->ToString()); param_proto->set_name(param->ToString());
SetNodeOutputType(param, param_proto->mutable_type()); SetNodeOutputType(param, param_proto->mutable_type());
@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph
} }
} }
void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
std::map<AnfNodePtr, size_t>* const_map_ptr) { std::map<AnfNodePtr, size_t> *const_map_ptr) {
if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
return; return;
} }
// topo sort nodes // topo sort nodes
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
std::map<AnfNodePtr, size_t> apply_map; std::map<AnfNodePtr, size_t> apply_map;
for (const AnfNodePtr& node : nodes) { for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt
} }
} }
void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* apply_map_ptr, std::map<AnfNodePtr, size_t> *apply_map_ptr,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
graph_proto == nullptr) { graph_proto == nullptr) {
return; return;
@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
auto apply_idx = apply_map_ptr->size() + 1; auto apply_idx = apply_map_ptr->size() + 1;
(*apply_map_ptr)[node] = apply_idx; (*apply_map_ptr)[node] = apply_idx;
auto& inputs = node->inputs(); auto &inputs = node->inputs();
if (inputs.size() < 1) { if (inputs.size() < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
} }
AnfNodePtr op = inputs[0]; AnfNodePtr op = inputs[0];
irpb::NodeProto* node_proto = graph_proto->add_node(); irpb::NodeProto *node_proto = graph_proto->add_node();
// CNode/ConstGraph/Const/Parameter // CNode/ConstGraph/Const/Parameter
if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) { if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
// process OP inputs // process OP inputs
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
irpb::InputProto* input_proto = node_proto->add_input(); irpb::InputProto *input_proto = node_proto->add_input();
input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
input_proto->set_name(id); input_proto->set_name(id);
@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
} }
} }
void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
const std::map<AnfNodePtr, size_t>& apply_map, const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
if (ret_node == nullptr || !ret_node->isa<CNode>()) { if (ret_node == nullptr || !ret_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Graph return node is illegal"; MS_LOG(EXCEPTION) << "Graph return node is illegal";
} }
@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
if (graph_proto == nullptr) { if (graph_proto == nullptr) {
MS_LOG(EXCEPTION) << "graph_proto is nullptr"; MS_LOG(EXCEPTION) << "graph_proto is nullptr";
} }
irpb::OutputProto* output_proto = graph_proto->add_outputs(); irpb::OutputProto *output_proto = graph_proto->add_outputs();
if (output_proto == nullptr) { if (output_proto == nullptr) {
MS_LOG(EXCEPTION) << "output_proto is nullptr"; MS_LOG(EXCEPTION) << "output_proto is nullptr";
} }
@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
SetNodeOutputType(arg, output_proto->mutable_type()); SetNodeOutputType(arg, output_proto->mutable_type());
} }
static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) { static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
return x.second < y.second; return x.second < y.second;
} }
void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) { void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
std::vector<std::pair<AnfNodePtr, size_t>> nodes; std::vector<std::pair<AnfNodePtr, size_t>> nodes;
(void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
[](const std::pair<AnfNodePtr, size_t>& item) { return item; }); [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
sort(nodes.begin(), nodes.end(), CompareValue); sort(nodes.begin(), nodes.end(), CompareValue);
for (auto& item : nodes) { for (auto &item : nodes) {
if (graph_proto == nullptr) { if (graph_proto == nullptr) {
MS_LOG(EXCEPTION) << "graph_proto is nullptr"; MS_LOG(EXCEPTION) << "graph_proto is nullptr";
} }
irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
MS_EXCEPTION_IF_NULL(named_value); MS_EXCEPTION_IF_NULL(named_value);
named_value->set_key(GetConstNodeId(item.second)); named_value->set_key(GetConstNodeId(item.second));
SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m
void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); }
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
ProtoExporter exporter; ProtoExporter exporter;
return exporter.GetFuncGraphProtoString(func_graph); return exporter.GetFuncGraphProtoString(func_graph);
} }

View File

@ -36,7 +36,7 @@ Dump::Dump()
dump_iter_(0), dump_iter_(0),
cur_iter_(0) {} cur_iter_(0) {}
bool Dump::IsKernelNeedDump(const std::string& kernel_name) { bool Dump::IsKernelNeedDump(const std::string &kernel_name) {
if (dump_mode_ == 0) { if (dump_mode_ == 0) {
// Dump All Kernels mode // Dump All Kernels mode
return true; return true;
@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
return false; return false;
} }
bool Dump::ParseDumpConfig(const std::string& dump_config_file) { bool Dump::ParseDumpConfig(const std::string &dump_config_file) {
std::ifstream jsonFile(dump_config_file); std::ifstream jsonFile(dump_config_file);
if (!jsonFile.is_open()) { if (!jsonFile.is_open()) {
MS_LOG(ERROR) << dump_config_file << " open failed."; MS_LOG(ERROR) << dump_config_file << " open failed.";
@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
return true; return true;
} }
bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) {
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
return true; return true;
} }
bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto trans_flag = dumpSettings.at("trans_flag"); auto trans_flag = dumpSettings.at("trans_flag");
auto enable = dumpSettings.at("enable"); auto enable = dumpSettings.at("enable");
auto mode = dumpSettings.at("mode"); auto mode = dumpSettings.at("mode");
@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
dump_path_ = path; dump_path_ = path;
dump_net_name_ = net_name; dump_net_name_ = net_name;
dump_iter_ = iteration; dump_iter_ = iteration;
for (const auto& kernel : kernels) { for (const auto &kernel : kernels) {
dump_kernels_.push_back(kernel); dump_kernels_.push_back(kernel);
} }
return true; return true;
} }
bool Dump::SetDumpConfFromJsonFile() { bool Dump::SetDumpConfFromJsonFile() {
const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
if (config_path_str != nullptr) { if (config_path_str != nullptr) {
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
} else { } else {
@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() {
return ParseDumpConfig(dump_config_file); return ParseDumpConfig(dump_config_file);
} }
bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) {
if (filename.empty() || data == nullptr || len == 0) { if (filename.empty() || data == nullptr || len == 0) {
MS_LOG(ERROR) << "Incorrect parameter."; MS_LOG(ERROR) << "Incorrect parameter.";
return false; return false;
@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len)
MS_LOG(ERROR) << "Open file " << realpath << " fail."; MS_LOG(ERROR) << "Open file " << realpath << " fail.";
return false; return false;
} }
(void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len)); (void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len));
fd.close(); fd.close();
return true; return true;
} }
bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) {
MS_EXCEPTION_IF_NULL(outpath); MS_EXCEPTION_IF_NULL(outpath);
auto path_split_pos = inpath.find_last_of('/'); auto path_split_pos = inpath.find_last_of('/');
if (path_split_pos == std::string::npos) { if (path_split_pos == std::string::npos) {
@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
return true; return true;
} }
bool Dump::CreateNotExistDirs(const std::string& path) { bool Dump::CreateNotExistDirs(const std::string &path) {
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem(); std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs); MS_EXCEPTION_IF_NULL(fs);
char temp_path[PATH_MAX] = {0}; char temp_path[PATH_MAX] = {0};

View File

@ -43,11 +43,11 @@ class Dump {
uint32_t cur_iter() const { return cur_iter_; } uint32_t cur_iter() const { return cur_iter_; }
bool IsKernelNeedDump(const std::string& kernel_name); bool IsKernelNeedDump(const std::string &kernel_name);
bool SetDumpConfFromJsonFile(); bool SetDumpConfFromJsonFile();
static bool DumpToFile(const std::string& filename, const void* data, size_t len); static bool DumpToFile(const std::string &filename, const void *data, size_t len);
protected: protected:
bool dump_enable_; bool dump_enable_;
@ -59,14 +59,14 @@ class Dump {
uint32_t cur_iter_; uint32_t cur_iter_;
std::vector<std::string> dump_kernels_; std::vector<std::string> dump_kernels_;
static bool GetRealPath(const std::string& inpath, std::string* outpath); static bool GetRealPath(const std::string &inpath, std::string *outpath);
static bool CreateNotExistDirs(const std::string& path); static bool CreateNotExistDirs(const std::string &path);
private: private:
bool ParseDumpConfig(const std::string& dump_config_file); bool ParseDumpConfig(const std::string &dump_config_file);
bool IsConfigExist(const nlohmann::json& dumpSettings); bool IsConfigExist(const nlohmann::json &dumpSettings);
bool IsConfigValid(const nlohmann::json& dumpSettings); bool IsConfigValid(const nlohmann::json &dumpSettings);
}; };
using DumpConfPtr = std::shared_ptr<Dump>; using DumpConfPtr = std::shared_ptr<Dump>;

View File

@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h" #include "pipeline/parse/python_adapter.h"
namespace mindspore { namespace mindspore {
std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
std::string temp_line = line; std::string temp_line = line;
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
tip != kSourceLineTipDiscard) { tip != kSourceLineTipDiscard) {
@ -101,14 +101,14 @@ DebugInfo::DebugInfo() {
name_ = ""; name_ = "";
} }
DebugInfo::DebugInfo(const std::string& name) { DebugInfo::DebugInfo(const std::string &name) {
InitValueFromContext(); InitValueFromContext();
unique_id_ = gen_unique_id(); unique_id_ = gen_unique_id();
debug_id_ = -1; debug_id_ = -1;
name_ = name; name_ = name;
} }
DebugInfo::DebugInfo(const LocationPtr& loc) { DebugInfo::DebugInfo(const LocationPtr &loc) {
InitValueFromContext(); InitValueFromContext();
unique_id_ = gen_unique_id(); unique_id_ = gen_unique_id();
debug_id_ = -1; debug_id_ = -1;
@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() {
} }
int64_t DebugInfo::unique_id_through_copy() const { int64_t DebugInfo::unique_id_through_copy() const {
TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info(); TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info();
if (trace_info != nullptr) { if (trace_info != nullptr) {
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
return trace_info->debug_info()->unique_id_through_copy(); return trace_info->debug_info()->unique_id_through_copy();
@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() {
} }
return DebugInfo::location(); return DebugInfo::location();
} }
void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; }
TraceContextPtr TraceManager::CurrentContextInfo() { TraceContextPtr TraceManager::CurrentContextInfo() {
if (!TraceManager::trace_context_stack_.empty()) { if (!TraceManager::trace_context_stack_.empty()) {
@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() {
return nullptr; return nullptr;
} }
void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location); TraceContextPtr context = std::make_shared<TraceContext>(location);
context->set_func_name(func_name); context->set_func_name(func_name);
TraceManager::trace_context_stack_.push(context); TraceManager::trace_context_stack_.push(context);
} }
void TraceManager::DebugTrace(const LocationPtr& location) { void TraceManager::DebugTrace(const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location); TraceContextPtr context = std::make_shared<TraceContext>(location);
TraceManager::trace_context_stack_.push(context); TraceManager::trace_context_stack_.push(context);
} }
void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) { if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
} }
@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
TraceManager::trace_context_stack_.push(context); TraceManager::trace_context_stack_.push(context);
} }
void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) { if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
} }

View File

@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou
// Location class record the location in source code. // Location class record the location in source code.
class Location { class Location {
public: public:
Location(const std::string& file_name, int line, int column, int line_end, int column_end) Location(const std::string &file_name, int line, int column, int line_end, int column_end)
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
Location(const Location& loc) Location(const Location &loc)
: file_name_(loc.file_name_), : file_name_(loc.file_name_),
line_(loc.line_), line_(loc.line_),
column_(loc.column_), column_(loc.column_),
@ -77,21 +77,21 @@ class TraceManager {
TraceManager() = default; TraceManager() = default;
~TraceManager() = default; ~TraceManager() = default;
static TraceContextPtr CurrentContextInfo(); static TraceContextPtr CurrentContextInfo();
static void DebugTrace(const std::string& func_name, const LocationPtr& location); static void DebugTrace(const std::string &func_name, const LocationPtr &location);
static void DebugTrace(const LocationPtr& location); static void DebugTrace(const LocationPtr &location);
static void DebugTrace(const TraceInfoPtr& trace_info); static void DebugTrace(const TraceInfoPtr &trace_info);
// debug trace with a cloned trace info with debug_info // debug trace with a cloned trace info with debug_info
static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
static void EndTrace(); static void EndTrace();
static std::stack<TraceContextPtr> trace_context_stack_; static std::stack<TraceContextPtr> trace_context_stack_;
}; };
class TraceGuard { class TraceGuard {
public: public:
explicit TraceGuard(const std::string func_name, const LocationPtr& location) { explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceManager::DebugTrace(func_name, location); TraceManager::DebugTrace(func_name, location);
} }
explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
~TraceGuard() { TraceManager::EndTrace(); } ~TraceGuard() { TraceManager::EndTrace(); }
}; };
@ -106,23 +106,23 @@ class TraceContext {
public: public:
~TraceContext() = default; ~TraceContext() = default;
explicit TraceContext(const LocationPtr& loc) { explicit TraceContext(const LocationPtr &loc) {
ProcessAttributeFromContext(); ProcessAttributeFromContext();
location_ = loc; location_ = loc;
} }
explicit TraceContext(const std::string& func_name) { explicit TraceContext(const std::string &func_name) {
ProcessAttributeFromContext(); ProcessAttributeFromContext();
func_name_ = func_name; func_name_ = func_name;
} }
explicit TraceContext(const TraceInfoPtr& trace_info) { explicit TraceContext(const TraceInfoPtr &trace_info) {
ProcessAttributeFromContext(); ProcessAttributeFromContext();
trace_info_ = trace_info; trace_info_ = trace_info;
} }
void set_location(const LocationPtr& loc) { location_ = loc; } void set_location(const LocationPtr &loc) { location_ = loc; }
LocationPtr location() { return location_; } LocationPtr location() { return location_; }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; } TraceInfoPtr trace_info() { return trace_info_; }
void set_func_name(const std::string& func_name) { func_name_ = func_name; } void set_func_name(const std::string &func_name) { func_name_ = func_name; }
std::string func_name() { return func_name_; } std::string func_name() { return func_name_; }
}; };
@ -130,9 +130,9 @@ class DebugInfo : public Base {
public: public:
DebugInfo(); DebugInfo();
explicit DebugInfo(const std::string& name); explicit DebugInfo(const std::string &name);
explicit DebugInfo(const LocationPtr& loc); explicit DebugInfo(const LocationPtr &loc);
virtual ~DebugInfo() = default; virtual ~DebugInfo() = default;
MS_DECLARE_PARENT(DebugInfo, Base); MS_DECLARE_PARENT(DebugInfo, Base);
@ -141,12 +141,12 @@ class DebugInfo : public Base {
int64_t unique_id_through_copy() const; int64_t unique_id_through_copy() const;
std::string get_id() { return std::to_string(debug_id()); } std::string get_id() { return std::to_string(debug_id()); }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; } TraceInfoPtr trace_info() { return trace_info_; }
void set_location(const LocationPtr& loc) { location_ = loc; } void set_location(const LocationPtr &loc) { location_ = loc; }
virtual LocationPtr location() { return location_; } virtual LocationPtr location() { return location_; }
std::string name() { return name_; } std::string name() { return name_; }
void set_name(const std::string& name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
virtual std::string debug_name(); virtual std::string debug_name();
virtual std::string get_python_func_belonged() { return ""; } virtual std::string get_python_func_belonged() { return ""; }
@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo {
py_func_belonged_ = context_info->func_name(); py_func_belonged_ = context_info->func_name();
} }
} }
explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) { if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo(); auto context_info = TraceManager::CurrentContextInfo();
py_func_belonged_ = context_info->func_name(); py_func_belonged_ = context_info->func_name();
@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo {
~NodeDebugInfo() override = default; ~NodeDebugInfo() override = default;
std::string debug_name() override; std::string debug_name() override;
void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); } void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); } std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
std::string get_python_func_belonged() override { return py_func_belonged_; } std::string get_python_func_belonged() override { return py_func_belonged_; }
AnfNodeWeakPtr node_; AnfNodeWeakPtr node_;
std::string py_func_belonged_; std::string py_func_belonged_;
@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo {
} }
} }
explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) { if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo(); auto context_info = TraceManager::CurrentContextInfo();
py_func_name_ = context_info->func_name(); py_func_name_ = context_info->func_name();
@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo {
std::string debug_name() override; std::string debug_name() override;
LocationPtr location() override; LocationPtr location() override;
LocationPtr deco_location() { return deco_loc_; } LocationPtr deco_location() { return deco_loc_; }
void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
FuncGraphPtr get_graph() const { return func_graph_.lock(); } FuncGraphPtr get_graph() const { return func_graph_.lock(); }
void set_full_name(const std::string& name) { full_name_ = name; } void set_full_name(const std::string &name) { full_name_ = name; }
std::string get_full_name() { return full_name_; } std::string get_full_name() { return full_name_; }
void set_deco_location(const LocationPtr& deco_list_loc); void set_deco_location(const LocationPtr &deco_list_loc);
std::string get_python_func_belonged() override { return py_func_name_; } std::string get_python_func_belonged() override { return py_func_name_; }
FuncGraphWeakPtr func_graph_; FuncGraphWeakPtr func_graph_;
LocationPtr deco_loc_; LocationPtr deco_loc_;

View File

@ -31,7 +31,7 @@ struct NameWithTrace {
std::string name; std::string name;
std::vector<std::string> trace_labels; std::vector<std::string> trace_labels;
}; };
static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
switch (trace_label) { switch (trace_label) {
case TraceLabelType::kShortSymbol: case TraceLabelType::kShortSymbol:
return trace_info->symbol(); return trace_info->symbol();
@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t
} }
} }
NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name; NameWithTrace trace_name;
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
auto temp_info = debug_info; auto temp_info = debug_info;
@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe
return trace_name; return trace_name;
} }
std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) { std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
std::string tags = ""; std::string tags = "";
for (auto& itr : trace_labels) { for (auto &itr : trace_labels) {
std::string symbol = itr; std::string symbol = itr;
tags = tags + symbol; tags = tags + symbol;
} }
@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st
} }
// get the label name of the node debug info // get the label name of the node debug info
std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name = RootName(debug_info, trace_label); NameWithTrace trace_name = RootName(debug_info, trace_label);
return CombineTraceTypes(trace_name.name, trace_name.trace_labels); return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
} }
std::string CombineUniqueID(const DebugInfoPtr& debug_info) { std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
auto temp_info = debug_info; auto temp_info = debug_info;
std::string label = ""; std::string label = "";
while (temp_info != nullptr) { while (temp_info != nullptr) {
@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
} }
// get trace with unique id chain // get trace with unique id chain
std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); } std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
return LabelStringUnique(debug_info); return LabelStringUnique(debug_info);
} }

View File

@ -29,7 +29,7 @@ namespace label_manage {
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId }; enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
TraceLabelType GetGlobalTraceLabelType(); TraceLabelType GetGlobalTraceLabelType();
void SetGlobalTraceLabelType(TraceLabelType label_type); void SetGlobalTraceLabelType(TraceLabelType label_type);
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
} // namespace label_manage } // namespace label_manage
} // namespace mindspore } // namespace mindspore

View File

@ -37,7 +37,7 @@
namespace mindspore { namespace mindspore {
// namespace to support debug trace infomation // namespace to support debug trace infomation
namespace trace { namespace trace {
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) { std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) {
if (abs == nullptr) { if (abs == nullptr) {
return "Null Abstract"; return "Null Abstract";
} }
@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) {
return debug_with_loc_vec; return debug_with_loc_vec;
} }
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) {
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
if (debug_with_loc_vec.size() > 0) { if (debug_with_loc_vec.size() > 0) {
return debug_with_loc_vec[0]; return debug_with_loc_vec[0];
@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
} }
} }
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
} }
@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
// a trace info identifies a node transform, so we can trace the node transform through // a trace info identifies a node transform, so we can trace the node transform through
// a link of trace info and debug info // a link of trace info and debug info
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) { std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) {
if (info_vec.size() < 1) { if (info_vec.size() < 1) {
return ""; return "";
} }
@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL
return traced_info; return traced_info;
} }
std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
} }
@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
return ""; return "";
} }
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
std::ostringstream oss; std::ostringstream oss;
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So
return oss.str(); return oss.str();
} }
std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) {
std::ostringstream oss; std::ostringstream oss;
oss << "graph:" << graph->ToString() << " with args["; oss << "graph:" << graph->ToString() << " with args[";
auto params = graph->parameters(); auto params = graph->parameters();
@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas
return oss.str(); return oss.str();
} }
void DumpInferStack(std::ostringstream& oss) { void DumpInferStack(std::ostringstream &oss) {
auto& infer_stack = GetCurrenGraphInferStack(); auto &infer_stack = GetCurrenGraphInferStack();
if (infer_stack.empty()) { if (infer_stack.empty()) {
return; return;
} }
@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) {
} }
std::reverse(infer_vec.begin(), infer_vec.end()); std::reverse(infer_vec.begin(), infer_vec.end());
int index = 0; int index = 0;
for (auto& item : infer_vec) { for (auto &item : infer_vec) {
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first); auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
if (graph_infer == nullptr) { if (graph_infer == nullptr) {
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) {
} }
void TraceGraphInfer() { void TraceGraphInfer() {
auto& infer_stack = GetCurrenGraphInferStack(); auto &infer_stack = GetCurrenGraphInferStack();
std::ostringstream oss; std::ostringstream oss;
if (infer_stack.empty()) { if (infer_stack.empty()) {
return; return;
@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
~AnalyzedFuncGraphExporter() override = default; ~AnalyzedFuncGraphExporter() override = default;
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs);
private: private:
std::string GetNodeType(const AnfNodePtr& nd) override; std::string GetNodeType(const AnfNodePtr &nd) override;
}; };
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack(); auto &list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) { for (size_t i = 0; i < list.size(); ++i) {
auto node_cfg = list[i]; auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph(); auto fg = node_cfg->context()->func_graph();
@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() {
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
} }
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
if (node_cfg_ == nullptr) { if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node); return AnfExporter::GetNodeType(node);
} }
@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
return oss.str(); return oss.str();
} }
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) {
if (node_cfgs.empty()) { if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty"; MS_LOG(DEBUG) << "Node configs is empty";
return; return;
@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
auto tagged_func_graphs = CalcTaggedFuncGraphs(); auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output graph on the analysis stack // first output graph on the analysis stack
for (const auto& node_cfg : node_cfgs) { for (const auto &node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph(); auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it // the graph is already output, skip it
if (exported.find(fg) != exported.end()) { if (exported.find(fg) != exported.end()) {
@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
ofs.close(); ofs.close();
} }
void GetInferStackInfo(std::ostringstream& oss) { void GetInferStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Get graph analysis information begin"; MS_LOG(INFO) << "Get graph analysis information begin";
auto stack = GetCNodeDebugStack(); auto stack = GetCNodeDebugStack();
if (stack.empty()) { if (stack.empty()) {
@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info // trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) { if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
} }
@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An
} }
} }
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
if (eval == nullptr) { if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
} }
@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
} }
} }
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; } std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() { std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() {
return graph_infer_stack; return graph_infer_stack;
} }
void ClearTraceStack() { void ClearTraceStack() {

View File

@ -31,19 +31,19 @@
namespace mindspore { namespace mindspore {
namespace trace { namespace trace {
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine); SourceLineTip tip = kSourceLineTipNextLine);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphInfer(); void TraceGraphInfer();
void GetInferStackInfo(std::ostringstream& oss); void GetInferStackInfo(std::ostringstream &oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceInferCNodeLeave(); void TraceInferCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack(); std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack(); std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
void ClearTraceStack(); void ClearTraceStack();
} // namespace trace } // namespace trace
} // namespace mindspore } // namespace mindspore

View File

@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h" #include "pipeline/parse/python_adapter.h"
namespace mindspore { namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
} }

View File

@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>;
// namespace to support intermediate representation definition // namespace to support intermediate representation definition
class TraceInfo : public Base { class TraceInfo : public Base {
public: public:
TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) {
symbol_ = symbol; symbol_ = symbol;
full_name_ = full_name; full_name_ = full_name;
name_ = full_name_; name_ = full_name_;
debug_info_ = info; debug_info_ = info;
} }
TraceInfo(const TraceInfo& info) TraceInfo(const TraceInfo &info)
: Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {}
virtual ~TraceInfo() = default; virtual ~TraceInfo() = default;
MS_DECLARE_PARENT(TraceInfo, Base); MS_DECLARE_PARENT(TraceInfo, Base);
@ -55,8 +55,8 @@ class TraceInfo : public Base {
virtual std::string full_name() { return full_name_; } virtual std::string full_name() { return full_name_; }
virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); } virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); }
virtual std::string action_name() { return ""; } virtual std::string action_name() { return ""; }
virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); virtual std::string GetActionBetweenNode(const DebugInfoPtr &info);
void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; }
DebugInfoPtr debug_info() { return debug_info_; } DebugInfoPtr debug_info() { return debug_info_; }
DebugInfoPtr DebugInfoHasLoc(); DebugInfoPtr DebugInfoHasLoc();
std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo(); std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo();
@ -70,7 +70,7 @@ class TraceInfo : public Base {
class TracePhi : public TraceInfo { class TracePhi : public TraceInfo {
public: public:
explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {}
MS_DECLARE_PARENT(TracePhi, TraceInfo); MS_DECLARE_PARENT(TracePhi, TraceInfo);
~TracePhi() override = default; ~TracePhi() override = default;
TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); } TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); }
@ -78,8 +78,8 @@ class TracePhi : public TraceInfo {
class TraceIfStmtTrueBranch : public TraceInfo { class TraceIfStmtTrueBranch : public TraceInfo {
public: public:
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default;
explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "") {} explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "") {}
MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo);
~TraceIfStmtTrueBranch() override = default; ~TraceIfStmtTrueBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo {
class TraceIfStmtFalseBranch : public TraceInfo { class TraceIfStmtFalseBranch : public TraceInfo {
public: public:
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default;
explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "") {} explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "") {}
MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo);
~TraceIfStmtFalseBranch() override = default; ~TraceIfStmtFalseBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo {
class TraceIfStmtAfterBranch : public TraceInfo { class TraceIfStmtAfterBranch : public TraceInfo {
public: public:
explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "") {} explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "") {}
MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo);
~TraceIfStmtAfterBranch() override = default; ~TraceIfStmtAfterBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo {
class TraceIfExpTrueBranch : public TraceInfo { class TraceIfExpTrueBranch : public TraceInfo {
public: public:
explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "") {} explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "") {}
MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo);
~TraceIfExpTrueBranch() override = default; ~TraceIfExpTrueBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo {
class TraceIfExpFalseBranch : public TraceInfo { class TraceIfExpFalseBranch : public TraceInfo {
public: public:
explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "") {} explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "") {}
MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo);
~TraceIfExpFalseBranch() override = default; ~TraceIfExpFalseBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo {
class TraceCopy : public TraceInfo { class TraceCopy : public TraceInfo {
public: public:
TraceCopy() : TraceInfo(nullptr, "copy", "") {} TraceCopy() : TraceInfo(nullptr, "copy", "") {}
explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {}
MS_DECLARE_PARENT(TraceCopy, TraceInfo); MS_DECLARE_PARENT(TraceCopy, TraceInfo);
~TraceCopy() override = default; ~TraceCopy() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); } TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); }
@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo {
class TraceIterator : public TraceInfo { class TraceIterator : public TraceInfo {
public: public:
explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {}
MS_DECLARE_PARENT(TraceIterator, TraceInfo); MS_DECLARE_PARENT(TraceIterator, TraceInfo);
~TraceIterator() override = default; ~TraceIterator() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); } TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); }
@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo {
class TraceWhileHeader : public TraceInfo { class TraceWhileHeader : public TraceInfo {
public: public:
explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "") {} explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "") {}
MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo);
~TraceWhileHeader() override = default; ~TraceWhileHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); } TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); }
@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo {
class TraceWhileBody : public TraceInfo { class TraceWhileBody : public TraceInfo {
public: public:
explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "") {} explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "") {}
MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); MS_DECLARE_PARENT(TraceWhileBody, TraceInfo);
~TraceWhileBody() override = default; ~TraceWhileBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); } TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); }
@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo {
class TraceWhileAfter : public TraceInfo { class TraceWhileAfter : public TraceInfo {
public: public:
explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "") {} explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "") {}
MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo);
~TraceWhileAfter() override = default; ~TraceWhileAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); } TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); }
@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo {
class TraceForHeader : public TraceInfo { class TraceForHeader : public TraceInfo {
public: public:
explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "") {} explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "") {}
MS_DECLARE_PARENT(TraceForHeader, TraceInfo); MS_DECLARE_PARENT(TraceForHeader, TraceInfo);
~TraceForHeader() override = default; ~TraceForHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); }
@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo {
class TraceForBody : public TraceInfo { class TraceForBody : public TraceInfo {
public: public:
explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "") {} explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "") {}
MS_DECLARE_PARENT(TraceForBody, TraceInfo); MS_DECLARE_PARENT(TraceForBody, TraceInfo);
~TraceForBody() override = default; ~TraceForBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); }
@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo {
class TraceForAfter : public TraceInfo { class TraceForAfter : public TraceInfo {
public: public:
explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "") {} explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "") {}
MS_DECLARE_PARENT(TraceForAfter, TraceInfo); MS_DECLARE_PARENT(TraceForAfter, TraceInfo);
~TraceForAfter() override = default; ~TraceForAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo {
class TraceEquiv : public TraceInfo { class TraceEquiv : public TraceInfo {
public: public:
explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {}
MS_DECLARE_PARENT(TraceEquiv, TraceInfo); MS_DECLARE_PARENT(TraceEquiv, TraceInfo);
~TraceEquiv() override = default; ~TraceEquiv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); } TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); }
@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo {
class TraceGradFpropApp : public TraceInfo { class TraceGradFpropApp : public TraceInfo {
public: public:
TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "") {} TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "") {}
explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "") {} explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "") {}
MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo);
~TraceGradFpropApp() override = default; ~TraceGradFpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); }
@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo {
class TraceGradBpropApp : public TraceInfo { class TraceGradBpropApp : public TraceInfo {
public: public:
TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "") {} TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "") {}
explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "") {} explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "") {}
MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo);
~TraceGradBpropApp() override = default; ~TraceGradBpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); }
@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo {
class TraceGradFprop : public TraceInfo { class TraceGradFprop : public TraceInfo {
public: public:
TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "") {} TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "") {}
explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "") {} explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "") {}
MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); MS_DECLARE_PARENT(TraceGradFprop, TraceInfo);
~TraceGradFprop() override = default; ~TraceGradFprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); }
@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo {
class TraceGradBprop : public TraceInfo { class TraceGradBprop : public TraceInfo {
public: public:
TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "") {} TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "") {}
explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "") {} explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "") {}
MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); MS_DECLARE_PARENT(TraceGradBprop, TraceInfo);
~TraceGradBprop() override = default; ~TraceGradBprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); }
@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo {
class TraceGradSens : public TraceInfo { class TraceGradSens : public TraceInfo {
public: public:
TraceGradSens() : TraceInfo(nullptr, "grad_sens", "") {} TraceGradSens() : TraceInfo(nullptr, "grad_sens", "") {}
explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "") {} explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "") {}
MS_DECLARE_PARENT(TraceGradSens, TraceInfo); MS_DECLARE_PARENT(TraceGradSens, TraceInfo);
~TraceGradSens() override = default; ~TraceGradSens() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); }
@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo {
class TraceSpecialize : public TraceInfo { class TraceSpecialize : public TraceInfo {
public: public:
explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); MS_DECLARE_PARENT(TraceSpecialize, TraceInfo);
std::string name() override { return full_name_ + counter_; } std::string name() override { return full_name_ + counter_; }
std::string symbol() override { return counter_ + "_"; } std::string symbol() override { return counter_ + "_"; }
@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo {
class TraceGradOperation : public TraceInfo { class TraceGradOperation : public TraceInfo {
public: public:
explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {}
MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); MS_DECLARE_PARENT(TraceGradOperation, TraceInfo);
~TraceGradOperation() override = default; ~TraceGradOperation() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo {
class TraceForceBool : public TraceInfo { class TraceForceBool : public TraceInfo {
public: public:
explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {}
MS_DECLARE_PARENT(TraceForceBool, TraceInfo); MS_DECLARE_PARENT(TraceForceBool, TraceInfo);
~TraceForceBool() override = default; ~TraceForceBool() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo {
class TraceExpandJ : public TraceInfo { class TraceExpandJ : public TraceInfo {
public: public:
explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}
MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); MS_DECLARE_PARENT(TraceExpandJ, TraceInfo);
~TraceExpandJ() override = default; ~TraceExpandJ() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); } TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); }
@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo {
class TraceGenMetaFuncGraph : public TraceInfo { class TraceGenMetaFuncGraph : public TraceInfo {
public: public:
explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo);
~TraceGenMetaFuncGraph() override = default; ~TraceGenMetaFuncGraph() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo {
class TraceEvaluatorGenGraph : public TraceInfo { class TraceEvaluatorGenGraph : public TraceInfo {
public: public:
explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo);
~TraceEvaluatorGenGraph() override = default; ~TraceEvaluatorGenGraph() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo {
class TraceResolve : public TraceInfo { class TraceResolve : public TraceInfo {
public: public:
explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {}
MS_DECLARE_PARENT(TraceResolve, TraceInfo); MS_DECLARE_PARENT(TraceResolve, TraceInfo);
~TraceResolve() override = default; ~TraceResolve() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); } TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); }
@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo {
class TraceTransform : public TraceInfo { class TraceTransform : public TraceInfo {
public: public:
TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; }
explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") {
transform_name_ = transform_name; transform_name_ = transform_name;
} }
@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo {
class TraceGenerateVarArg : public TraceInfo { class TraceGenerateVarArg : public TraceInfo {
public: public:
explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {}
MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo);
~TraceGenerateVarArg() override = default; ~TraceGenerateVarArg() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo {
class TraceGenerateKwArg : public TraceInfo { class TraceGenerateKwArg : public TraceInfo {
public: public:
explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {}
MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo);
~TraceGenerateKwArg() override = default; ~TraceGenerateKwArg() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo {
class TraceTrasformK : public TraceInfo { class TraceTrasformK : public TraceInfo {
public: public:
explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {}
MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); MS_DECLARE_PARENT(TraceTrasformK, TraceInfo);
~TraceTrasformK() override = default; ~TraceTrasformK() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); } TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); }
@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo {
class TracePartialTransform : public TraceInfo { class TracePartialTransform : public TraceInfo {
public: public:
explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {}
MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); MS_DECLARE_PARENT(TracePartialTransform, TraceInfo);
~TracePartialTransform() override = default; ~TracePartialTransform() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo {
class TraceGetEnv : public TraceInfo { class TraceGetEnv : public TraceInfo {
public: public:
explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
~TraceGetEnv() override = default; ~TraceGetEnv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); }
@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo {
class TraceDoSignature : public TraceInfo { class TraceDoSignature : public TraceInfo {
public: public:
explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {}
MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); MS_DECLARE_PARENT(TraceDoSignature, TraceInfo);
~TraceDoSignature() override = default; ~TraceDoSignature() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); } TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); }
@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo {
class TraceCombileLikeGraphs : public TraceInfo { class TraceCombileLikeGraphs : public TraceInfo {
public: public:
TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {}
explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {}
MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo);
~TraceCombileLikeGraphs() override = default; ~TraceCombileLikeGraphs() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {

View File

@ -21,7 +21,7 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (has_malloc_) { if (has_malloc_) {
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
} }
@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return size; return size;
} }
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
MS_EXCEPTION_IF_NULL(addr); MS_EXCEPTION_IF_NULL(addr);
has_malloc_ = false; has_malloc_ = false;
free_mem_size_ = total_mem_size_; free_mem_size_ = total_mem_size_;
@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
MS_EXCEPTION_IF_NULL(device_mem_pool_base); MS_EXCEPTION_IF_NULL(device_mem_pool_base);
device_mem_pool_base_ = device_mem_pool_base; device_mem_pool_base_ = device_mem_pool_base;
} }

View File

@ -26,12 +26,12 @@ namespace ascend {
class AscendMemoryPool : public DynamicMemPoolBestFit { class AscendMemoryPool : public DynamicMemPoolBestFit {
public: public:
~AscendMemoryPool() override = default; ~AscendMemoryPool() override = default;
AscendMemoryPool(const AscendMemoryPool&) = delete; AscendMemoryPool(const AscendMemoryPool &) = delete;
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; AscendMemoryPool &operator=(const AscendMemoryPool &) = delete;
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override; bool FreeDeviceMem(const DeviceMemPtr &addr) override;
void set_device_mem_pool_base(uint8_t* device_mem_pool_base); void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
void set_device_mem_pool_size(uint64_t device_mem_pool_size) { void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
device_mem_pool_size_ = device_mem_pool_size; device_mem_pool_size_ = device_mem_pool_size;
free_mem_size_ = device_mem_pool_size_; free_mem_size_ = device_mem_pool_size_;
@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
size_t free_mem_size() override; size_t free_mem_size() override;
size_t total_mem_size() override; size_t total_mem_size() override;
static AscendMemoryPool& GetInstance() { static AscendMemoryPool &GetInstance() {
static AscendMemoryPool instance; static AscendMemoryPool instance;
return instance; return instance;
} }
@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
private: private:
AscendMemoryPool() = default; AscendMemoryPool() = default;
bool has_malloc_{false}; bool has_malloc_{false};
uint8_t* device_mem_pool_base_{nullptr}; uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0}; uint64_t device_mem_pool_size_{0};
size_t free_mem_size_{0}; size_t free_mem_size_{0};
size_t total_mem_size_{0}; size_t total_mem_size_{0};

View File

@ -39,13 +39,13 @@ using std::vector;
class AscendStreamAssign { class AscendStreamAssign {
public: public:
static AscendStreamAssign& GetInstance() { static AscendStreamAssign &GetInstance() {
static AscendStreamAssign instance; // Guaranteed to be destroyed. static AscendStreamAssign instance; // Guaranteed to be destroyed.
return instance; return instance;
} }
AscendStreamAssign(const AscendStreamAssign&) = delete; AscendStreamAssign(const AscendStreamAssign &) = delete;
AscendStreamAssign& operator=(const AscendStreamAssign&) = delete; AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
uint32_t GetTotalStreamNum() const; uint32_t GetTotalStreamNum() const;
// new stream policy // new stream policy
@ -53,19 +53,19 @@ class AscendStreamAssign {
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
uint32_t total_event_num() const { return total_event_num_; } uint32_t total_event_num() const { return total_event_num_; }
void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ResetNew(); void ResetNew();
void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsIndependentNode(const CNodePtr& node_ptr); bool IsIndependentNode(const CNodePtr &node_ptr);
const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; } const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; } const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; } const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
void GetWaitStreams(vector<uint32_t>* wait_active_stream_list); void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; } const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id); uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id); uint32_t stream_id);
private: private:
@ -73,30 +73,30 @@ class AscendStreamAssign {
~AscendStreamAssign() = default; ~AscendStreamAssign() = default;
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr& node); const CNodePtr &node);
bool IsHcom(const CNodePtr& apply_kernel); bool IsHcom(const CNodePtr &apply_kernel);
bool IsProcessed(uint32_t logic_id); bool IsProcessed(uint32_t logic_id);
void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids); void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
uint32_t* cur_stream_id); uint32_t *cur_stream_id);
void RecordIdMap(uint32_t logic_id, uint32_t physic_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
void UpdateStreamActive(const CNodePtr& active_ptr); void UpdateStreamActive(const CNodePtr &active_ptr);
void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
bool IsTaskSink(); bool IsTaskSink();
void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr); void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
void SetCommonStreamNum(uint32_t cur_stream_id); void SetCommonStreamNum(uint32_t cur_stream_id);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsProcessedParallelStream(uint32_t stream_id); bool IsProcessedParallelStream(uint32_t stream_id);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams); void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr); void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
uint32_t total_common_stream_num_{0}; uint32_t total_common_stream_num_{0};
uint32_t total_independ_stream_num_{0}; uint32_t total_independ_stream_num_{0};

View File

@ -28,14 +28,14 @@ namespace device {
namespace ascend { namespace ascend {
class PluginImpl : public PluginIntf { class PluginImpl : public PluginIntf {
public: public:
explicit PluginImpl(const std::string& module); explicit PluginImpl(const std::string &module);
~PluginImpl() override = default; ~PluginImpl() override = default;
int Init(const Reporter* reporter) override; int Init(const Reporter *reporter) override;
int UnInit() override; int UnInit() override;
static Reporter* GetPluginReporter() { return reporter_; } static Reporter *GetPluginReporter() { return reporter_; }
private: private:
static Reporter* reporter_; static Reporter *reporter_;
std::string module_; std::string module_;
}; };
} // namespace ascend } // namespace ascend

View File

@ -20,12 +20,12 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
PluginIntf* ProfilingEngineImpl::CreatePlugin() { PluginIntf *ProfilingEngineImpl::CreatePlugin() {
MS_LOG(INFO) << "Create Plugin."; MS_LOG(INFO) << "Create Plugin.";
return new (std::nothrow) PluginImpl("Framework"); return new (std::nothrow) PluginImpl("Framework");
} }
int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) {
if (plugin != nullptr) { if (plugin != nullptr) {
delete plugin; delete plugin;
} }

View File

@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf {
ProfilingEngineImpl() = default; ProfilingEngineImpl() = default;
~ProfilingEngineImpl() override = default; ~ProfilingEngineImpl() override = default;
PluginIntf* CreatePlugin() override; PluginIntf *CreatePlugin() override;
int ReleasePlugin(PluginIntf* plugin) override; int ReleasePlugin(PluginIntf *plugin) override;
}; };
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device

View File

@ -35,7 +35,7 @@ using Json = nlohmann::json;
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
ProfilingManager& ProfilingManager::GetInstance() { ProfilingManager &ProfilingManager::GetInstance() {
static ProfilingManager inst; static ProfilingManager inst;
return inst; return inst;
} }
@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) {
} }
uint64_t ProfilingManager::GetJobId() const { uint64_t ProfilingManager::GetJobId() const {
const char* job_id = std::getenv("JOB_ID"); const char *job_id = std::getenv("JOB_ID");
return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0);
} }
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const { bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const {
if (!IsProfiling()) { if (!IsProfiling()) {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return false; return false;
@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size();
Msprof::Engine::ReporterData reporter_data = {}; Msprof::Engine::ReporterData reporter_data = {};
for (const auto& iter : op_taskId_map) { for (const auto &iter : op_taskId_map) {
auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
reporter_data.deviceId = UintToInt(device_id_); reporter_data.deviceId = UintToInt(device_id_);
reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str())); reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
reporter_data.dataLen = data.size(); reporter_data.dataLen = data.size();
auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
if (ret != 0) { if (ret != 0) {
@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
return true; return true;
} }
static std::vector<std::string> Split(const std::string& str, const char delim) { static std::vector<std::string> Split(const std::string &str, const char delim) {
std::vector<std::string> elems; std::vector<std::string> elems;
if (str.empty()) { if (str.empty()) {
@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
device_id_ = device_id; device_id_ = device_id;
// exp: export PROFILING_MODE=true // exp: export PROFILING_MODE=true
// export PROFILING_OPTIONS=training_trace // export PROFILING_OPTIONS=training_trace
const char* prof_options_str = std::getenv("PROFILING_OPTIONS"); const char *prof_options_str = std::getenv("PROFILING_OPTIONS");
// register Framework to profiling // register Framework to profiling
int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
if (result != 0) { if (result != 0) {
@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return true; return true;
} }
Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
if (reporter != nullptr) { if (reporter != nullptr) {
MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); MS_LOG(INFO) << "report data end, ret = " << reporter->Flush();
} }

View File

@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
class GpuQueue { class GpuQueue {
public: public:
GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity);
virtual ~GpuQueue(); virtual ~GpuQueue();
void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; } void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; }
inline bool IsEmpty() const { return head_ == tail_; } inline bool IsEmpty() const { return head_ == tail_; }
inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); }
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const;
BlockQueueStatus_T Pop(); BlockQueueStatus_T Pop();
bool Destroy(); bool Destroy();
private: private:
struct NodeInfo { struct NodeInfo {
std::unique_ptr<cudaEvent_t> event_; std::unique_ptr<cudaEvent_t> event_;
void* host_feature_addr_; void *host_feature_addr_;
void* host_label_addr_; void *host_label_addr_;
}; };
void* buffer_; void *buffer_;
size_t head_; size_t head_;
size_t tail_; size_t tail_;
size_t feature_size_; size_t feature_size_;
@ -61,10 +61,10 @@ class GpuQueue {
size_t capacity_; size_t capacity_;
cudaStream_t stream_; cudaStream_t stream_;
std::unique_ptr<NodeInfo[]> node_info_; std::unique_ptr<NodeInfo[]> node_info_;
std::function<void(void*)> host_release_; std::function<void(void *)> host_release_;
GpuQueue(const GpuQueue&) = delete; GpuQueue(const GpuQueue &) = delete;
GpuQueue& operator=(const GpuQueue&) = delete; GpuQueue &operator=(const GpuQueue &) = delete;
}; };
class BlockingQueue { class BlockingQueue {
@ -72,11 +72,11 @@ class BlockingQueue {
BlockingQueue() : queue_(nullptr) {} BlockingQueue() : queue_(nullptr) {}
~BlockingQueue() = default; ~BlockingQueue() = default;
BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity);
void RegisterRelease(const std::function<void(void*)>& func); void RegisterRelease(const std::function<void(void *)> &func);
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size,
unsigned int timeout_in_sec); unsigned int timeout_in_sec);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size);
BlockQueueStatus_T Pop(); BlockQueueStatus_T Pop();
bool Destroy(); bool Destroy();

View File

@ -20,17 +20,17 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
CollectiveInitializer& CollectiveInitializer::instance() { CollectiveInitializer &CollectiveInitializer::instance() {
static CollectiveInitializer instance = {}; static CollectiveInitializer instance = {};
return instance; return instance;
} }
bool CollectiveInitializer::collective_inited() const { return collective_inited_; } bool CollectiveInitializer::collective_inited() const { return collective_inited_; }
const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } const void *CollectiveInitializer::collective_handle() const { return collective_handle_; }
void CollectiveInitializer::InitCollective() { void CollectiveInitializer::InitCollective() {
void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) { if (handle == nullptr) {
MS_LOG(EXCEPTION) MS_LOG(EXCEPTION)
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not "

View File

@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() {
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
} }
bool GPUDeviceManager::CreateStream(DeviceStream* stream) { bool GPUDeviceManager::CreateStream(DeviceStream *stream) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
gpu_streams_.emplace_back(*stream); gpu_streams_.emplace_back(*stream);
return true; return true;
} }
const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; }
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; }
bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {
return CudaDriver::CopyDeviceMemToHost(dst, src, size); return CudaDriver::CopyDeviceMemToHost(dst, src, size);
} }
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const {
return CudaDriver::CopyHostMemToDevice(dst, src, size); return CudaDriver::CopyHostMemToDevice(dst, src, size);
} }
} // namespace gpu } // namespace gpu

View File

@ -37,17 +37,17 @@ class GPUDeviceManager {
uint32_t cur_device_id() const; uint32_t cur_device_id() const;
bool is_device_id_init() const; bool is_device_id_init() const;
bool CreateStream(DeviceStream* stream); bool CreateStream(DeviceStream *stream);
bool SyncStream(const DeviceStream& stream) const; bool SyncStream(const DeviceStream &stream) const;
const DeviceStream& default_stream() const; const DeviceStream &default_stream() const;
const cudnnHandle_t& GetCudnnHandle() const; const cudnnHandle_t &GetCudnnHandle() const;
const cublasHandle_t& GetCublasHandle() const; const cublasHandle_t &GetCublasHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
static GPUDeviceManager& GetInstance() { static GPUDeviceManager &GetInstance() {
static GPUDeviceManager instance; static GPUDeviceManager instance;
return instance; return instance;
} }
@ -55,8 +55,8 @@ class GPUDeviceManager {
private: private:
GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {}
~GPUDeviceManager() = default; ~GPUDeviceManager() = default;
GPUDeviceManager(const GPUDeviceManager&) = delete; GPUDeviceManager(const GPUDeviceManager &) = delete;
GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; GPUDeviceManager &operator=(const GPUDeviceManager &) = delete;
// default CUDA stream used for all the kernels. // default CUDA stream used for all the kernels.
DeviceStream default_stream_{nullptr}; DeviceStream default_stream_{nullptr};

View File

@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() {
return true; return true;
} }
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
auto alloc_size = AllocDeviceMem(size, addr); auto alloc_size = AllocDeviceMem(size, addr);
buffer_q_addr_ = *addr; buffer_q_addr_ = *addr;
// Buffer queue needs to ensure that the alloc_size and size is equal. // Buffer queue needs to ensure that the alloc_size and size is equal.
return (alloc_size == size) ? true : false; return (alloc_size == size) ? true : false;
} }
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (size == 0) { if (size == 0) {
MS_LOG(EXCEPTION) << "The memory alloc size is 0."; MS_LOG(EXCEPTION) << "The memory alloc size is 0.";
} }
@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return alloc_size; return alloc_size;
} }
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); }
size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); }

View File

@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit {
~GPUMemoryAllocator() override = default; ~GPUMemoryAllocator() override = default;
bool Init(); bool Init();
bool Finalize(); bool Finalize();
bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr);
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override; bool FreeDeviceMem(const DeviceMemPtr &addr) override;
size_t free_mem_size() override; size_t free_mem_size() override;
size_t total_mem_size() override; size_t total_mem_size() override;
static GPUMemoryAllocator& GetInstance() { static GPUMemoryAllocator &GetInstance() {
static GPUMemoryAllocator instance; static GPUMemoryAllocator instance;
return instance; return instance;
} }
private: private:
GPUMemoryAllocator() = default; GPUMemoryAllocator() = default;
GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; GPUMemoryAllocator(const GPUMemoryAllocator &) = delete;
GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete;
// Used to track address of data buffer queue. // Used to track address of data buffer queue.
DeviceMemPtr buffer_q_addr_{nullptr}; DeviceMemPtr buffer_q_addr_{nullptr};

View File

@ -33,8 +33,8 @@ namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo; using mindspore::kernel::KernelBuildInfo;
namespace { namespace {
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info, bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(selected_kernel_info);
MS_EXCEPTION_IF_NULL(alternative_kernel_info); MS_EXCEPTION_IF_NULL(alternative_kernel_info);
size_t selected_input_num = selected_kernel_info->GetInputNum(); size_t selected_input_num = selected_kernel_info->GetInputNum();
@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_
return true; return true;
} }
std::string SupportedTypeList(const CNodePtr& kernel_node) { std::string SupportedTypeList(const CNodePtr &kernel_node) {
std::string supported_type_lists = std::string supported_type_lists =
kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
if (!supported_type_lists.empty()) { if (!supported_type_lists.empty()) {
@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) {
return supported_type_lists; return supported_type_lists;
} }
bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(selected_kernel_info);
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
} }
bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) { [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
}); });
if (!match) { if (!match) {
@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
return true; return true;
} }
void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) { void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1); auto input_kernel_node = kernel_node->input(input_index + 1);
@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co
} }
} // namespace } // namespace
void SetKernelInfo(const CNodePtr& kernel_node) { void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type; std::vector<TypeId> inputs_type;
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =

View File

@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
void SetKernelInfo(const CNodePtr& apply_kernel_ptr); void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
class KernelAttr { class KernelAttr {
public: public:
@ -35,24 +35,24 @@ class KernelAttr {
KernelAttr() : all_same_(false) {} KernelAttr() : all_same_(false) {}
~KernelAttr() = default; ~KernelAttr() = default;
KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
input_type_.emplace_back(ms_type, format); input_type_.emplace_back(ms_type, format);
return *this; return *this;
} }
KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
output_type_.emplace_back(ms_type, format); output_type_.emplace_back(ms_type, format);
return *this; return *this;
} }
KernelAttr& AddAllSameAttr(const bool& all_same) { KernelAttr &AddAllSameAttr(const bool &all_same) {
all_same_ = all_same; all_same_ = all_same;
return *this; return *this;
} }
const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool& GetAllSame() const { return all_same_; } const bool &GetAllSame() const { return all_same_; }
size_t GetInputSize() const { return input_type_.size(); } size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); } size_t GetOutputSize() const { return output_type_.size(); }

View File

@ -24,7 +24,7 @@
namespace mindspore { namespace mindspore {
struct TypeIdManager* TypeIdManager::Get() { struct TypeIdManager *TypeIdManager::Get() {
static TypeIdManager manager; static TypeIdManager manager;
return &manager; return &manager;
} }

View File

@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const { std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info()); return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
} }
CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph) CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
// Check if CNode is an apply with the specific Primitive. // Check if CNode is an apply with the specific Primitive.
bool CNode::IsApply(const PrimitivePtr& value) const { bool CNode::IsApply(const PrimitivePtr &value) const {
if (value == nullptr) { if (value == nullptr) {
return false; return false;
} }
@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const {
return false; return false;
} }
void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; }
std::string CNode::DebugString(int recursive_level) const { std::string CNode::DebugString(int recursive_level) const {
std::ostringstream buffer; std::ostringstream buffer;
@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const {
buffer << ToString() << "{"; buffer << ToString() << "{";
bool is_first_node = true; bool is_first_node = true;
int idx = 0; int idx = 0;
for (auto& node : inputs_) { for (auto &node : inputs_) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (is_first_node) { if (is_first_node) {
is_first_node = false; is_first_node = false;
@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str(); return buffer.str();
} }
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) { if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name(); << ", using the new one: " << operator_info->name();
@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() {
return fullname_with_scope_; return fullname_with_scope_;
} }
void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); } void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); } void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); } void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) { if (cnode != nullptr) {
@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
return false; return false;
} }
PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return nullptr; return nullptr;
} }
@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
return ""; return "";
} }
bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
if (IsValueNode<Primitive>(node)) { if (IsValueNode<Primitive>(node)) {
PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node); PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
} }
namespace id_generator { namespace id_generator {
static std::unordered_map<std::string, int> node_ids; static std::unordered_map<std::string, int> node_ids;
std::string get_id(const AnfNodePtr& node) { std::string get_id(const AnfNodePtr &node) {
auto type_name = node->type_name(); auto type_name = node->type_name();
if (node_ids.find(type_name) == node_ids.end()) { if (node_ids.find(type_name) == node_ids.end()) {
node_ids[type_name] = 0; node_ids[type_name] = 0;

View File

@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
class Base : public std::enable_shared_from_this<Base> { class Base : public std::enable_shared_from_this<Base> {
public: public:
constexpr Base() = default; constexpr Base() = default;
Base(const Base& other) : std::enable_shared_from_this<Base>(other) {} Base(const Base &other) : std::enable_shared_from_this<Base>(other) {}
virtual bool operator==(const Base& rhs) { virtual bool operator==(const Base &rhs) {
if (this == &rhs) { if (this == &rhs) {
return true; return true;
} }
return false; return false;
} }
virtual Base& operator=(const Base&) { return *this; } virtual Base &operator=(const Base &) { return *this; }
virtual ~Base() = default; virtual ~Base() = default;
virtual std::size_t hash() const { return tid(); } virtual std::size_t hash() const { return tid(); }
virtual std::string ToString() const { return type_name(); } virtual std::string ToString() const { return type_name(); }
@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> {
virtual const bool IsFromTypeId(uint32_t tid) const; virtual const bool IsFromTypeId(uint32_t tid) const;
virtual std::string type_name() const { return "Base"; } virtual std::string type_name() const { return "Base"; }
static uint32_t GetTypeId(const char* const type_key); static uint32_t GetTypeId(const char *const type_key);
virtual uint32_t tid() const { virtual uint32_t tid() const {
static const uint32_t tid = GetTypeId(typeid(Base).name()); static const uint32_t tid = GetTypeId(typeid(Base).name());
return tid; return tid;
} }
template <typename T, template <typename T,
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr> typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr>
inline bool isa() const { inline bool isa() const {
static const uint32_t tid = GetTypeId(typeid(T).name()); static const uint32_t tid = GetTypeId(typeid(T).name());
return this->IsFromTypeId(tid); return this->IsFromTypeId(tid);
@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>;
using BaseWeakPtr = std::weak_ptr<Base>; using BaseWeakPtr = std::weak_ptr<Base>;
template <typename T, typename U> template <typename T, typename U>
inline T* cast(U* source) { inline T *cast(U *source) {
if (source != nullptr && source->template isa<T>()) { if (source != nullptr && source->template isa<T>()) {
return static_cast<T*>(source); return static_cast<T *>(source);
} else { } else {
return nullptr; return nullptr;
} }
@ -100,7 +100,7 @@ inline T* cast(U* source) {
template < template <
typename T, typename U, typename T, typename U,
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr> typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr>
inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) { inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
if (r != nullptr && r->template isa<T>()) { if (r != nullptr && r->template isa<T>()) {
return std::static_pointer_cast<T>(r); return std::static_pointer_cast<T>(r);
@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager {
std::mutex mutex; std::mutex mutex;
std::atomic<uint32_t> type_counter{0}; std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> map; std::unordered_map<std::string, uint32_t> map;
static TypeIdManager* Get(); static TypeIdManager *Get();
TypeIdManager() : mutex(), type_counter(0), map() {} TypeIdManager() : mutex(), type_counter(0), map() {}
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -48,11 +48,11 @@ std::string Keyword::ToString() const {
return buffer.str(); return buffer.str();
} }
bool Keyword::operator==(const Type& other) const { bool Keyword::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const auto& other_keyword = static_cast<const Keyword&>(other); const auto &other_keyword = static_cast<const Keyword &>(other);
return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_);
} }
@ -87,11 +87,11 @@ std::string Slice::ToString() const {
return buffer.str(); return buffer.str();
} }
bool Slice::operator==(const Type& other) const { bool Slice::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_slice = static_cast<const Slice&>(other); auto other_slice = static_cast<const Slice &>(other);
return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_);
} }
@ -122,11 +122,11 @@ std::string TensorType::DumpText() const {
} }
} }
bool TensorType::operator==(const Type& other) const { bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_elem_type = static_cast<const TensorType&>(other).element_type_; auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array. // When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) { if (element_type_ == nullptr && other_elem_type == nullptr) {
return true; return true;
@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) {
retval_ = nullptr; retval_ = nullptr;
} }
Function::Function(const std::vector<TypePtr>& args, const TypePtr retval) Function::Function(const std::vector<TypePtr> &args, const TypePtr retval)
: Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {}
TypePtr Function::DeepCopy() const { TypePtr Function::DeepCopy() const {
@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const {
TypePtrList args; TypePtrList args;
TypePtr retval = nullptr; TypePtr retval = nullptr;
(void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args),
[](const TypePtr& arg) { return arg->DeepCopy(); }); [](const TypePtr &arg) { return arg->DeepCopy(); });
if (retval_ != nullptr) { if (retval_ != nullptr) {
retval = retval_->DeepCopy(); retval = retval_->DeepCopy();
} }
@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const {
} }
} }
bool Function::operator==(const Type& other) const { bool Function::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const auto& other_function = static_cast<const Function&>(other); const auto &other_function = static_cast<const Function &>(other);
if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) {
if (*retval_ != *other_function.retval_) { if (*retval_ != *other_function.retval_) {
return false; return false;
@ -188,7 +188,7 @@ std::string Function::ToString() const {
} else { } else {
buffer << "Func[("; buffer << "Func[(";
bool begin = true; bool begin = true;
for (auto& attr : args_) { for (auto &attr : args_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
@ -242,34 +242,34 @@ std::string JTagged::DumpText() const {
return buffer.str(); return buffer.str();
} }
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) { std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
MS_EXCEPTION_IF_NULL(problem); MS_EXCEPTION_IF_NULL(problem);
os << problem->ToString(); os << problem->ToString();
return os; return os;
} }
std::size_t TypeHasher::operator()(TypePtr const& type) const { std::size_t TypeHasher::operator()(TypePtr const &type) const {
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
std::size_t hash = std::hash<size_t>()(type->type_id()); std::size_t hash = std::hash<size_t>()(type->type_id());
return hash; return hash;
} }
std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
std::size_t hash_sum = 0; std::size_t hash_sum = 0;
for (auto& type : type_list) { for (auto &type : type_list) {
auto type_id = static_cast<std::size_t>(type->type_id()); auto type_id = static_cast<std::size_t>(type->type_id());
hash_sum = hash_combine(hash_sum, type_id); hash_sum = hash_combine(hash_sum, type_id);
} }
return hash_sum; return hash_sum;
} }
bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2); MS_EXCEPTION_IF_NULL(t2);
return t1->type_id() == t2->type_id(); return t1->type_id() == t2->type_id();
} }
bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
if (lhs.size() != rhs.size()) { if (lhs.size() != rhs.size()) {
return false; return false;
} }
@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) {
namespace { namespace {
template <typename T> template <typename T>
TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == num_type_name) { if (type_name == num_type_name) {
type = std::make_shared<T>(); type = std::make_shared<T>();
@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_
} }
auto bits = std::stoi(type_name.substr(num_type_name.size())); auto bits = std::stoi(type_name.substr(num_type_name.size()));
type = std::make_shared<T>(bits); type = std::make_shared<T>(bits);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what();
} }
} }
return type; return type;
} }
std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
std::vector<TypePtr> types; std::vector<TypePtr> types;
if (type_names.length() == 0) { if (type_names.length() == 0) {
return types; return types;
@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
return types; return types;
} }
TypePtr TensorStrToType(const std::string& type_name) { TypePtr TensorStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "Tensor") { if (type_name == "Tensor") {
type = std::make_shared<TensorType>(); type = std::make_shared<TensorType>();
@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return nullptr; return nullptr;
} }
type = std::make_shared<TensorType>(element_type); type = std::make_shared<TensorType>(element_type);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return type; return type;
} }
TypePtr ListStrToType(const std::string& type_name) { TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "List") { if (type_name == "List") {
type = std::make_shared<List>(); type = std::make_shared<List>();
@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start); std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong = bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) { if (wrong) {
return nullptr; return nullptr;
} }
type = std::make_shared<List>(element_types); type = std::make_shared<List>(element_types);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) {
return type; return type;
} }
TypePtr TupleStrToType(const std::string& type_name) { TypePtr TupleStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "Tuple") { if (type_name == "Tuple") {
type = std::make_shared<Tuple>(); type = std::make_shared<Tuple>();
@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start); std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong = bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) { if (wrong) {
return nullptr; return nullptr;
} }
type = std::make_shared<Tuple>(element_types); type = std::make_shared<Tuple>(element_types);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
return type; return type;
} }
TypePtr FunctionStrToType(const std::string& type_name) { TypePtr FunctionStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "Function") { if (type_name == "Function") {
@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) {
std::vector<TypePtr> args_type = StringToVectorOfType(str_args); std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
TypePtr retval = StringToType(str_retval); TypePtr retval = StringToType(str_retval);
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
if (retval == nullptr || wrong) { if (retval == nullptr || wrong) {
return nullptr; return nullptr;
} }
type = std::make_shared<Function>(args_type, retval); type = std::make_shared<Function>(args_type, retval);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) {
} }
} // namespace } // namespace
TypePtr StringToType(const std::string& type_name) { TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name.compare("None") == 0) { if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>(); type = std::make_shared<TypeNone>();
@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) {
return type; return type;
} }
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) { if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr."; MS_LOG(ERROR) << "Type is nullptr.";
return false; return false;
@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
} }
} }
bool IsSubType(TypePtr const& t1, TypePtr const& t2) { bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t1);
if (t1->type_id() == kTypeUnknown) { if (t1->type_id() == kTypeUnknown) {
return false; return false;
@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
} }
REGISTER_PYBIND_DEFINE( REGISTER_PYBIND_DEFINE(
typing, ([](py::module* const m) { typing, ([](py::module *const m) {
auto m_sub = m->def_submodule("typing", "submodule for dtype"); auto m_sub = m->def_submodule("typing", "submodule for dtype");
py::enum_<TypeId>(m_sub, "TypeId"); py::enum_<TypeId>(m_sub, "TypeId");
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
(void)m_sub.def("load_type", &TypeIdToType, "load type"); (void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def( (void)m_sub.def(
"dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__", .def("__eq__",
[](const TypePtr& t1, const TypePtr& t2) { [](const TypePtr &t1, const TypePtr &t2) {
if (t1 != nullptr && t2 != nullptr) { if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2; return *t1 == *t2;
} }
@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE(
.def("__hash__", &Type::hash) .def("__hash__", &Type::hash)
.def("__str__", &Type::ToString) .def("__str__", &Type::ToString)
.def("__repr__", &Type::ReprString) .def("__repr__", &Type::ReprString)
.def("__deepcopy__", [](const TypePtr& t, py::dict) { .def("__deepcopy__", [](const TypePtr &t, py::dict) {
if (t == nullptr) { if (t == nullptr) {
return static_cast<TypePtr>(nullptr); return static_cast<TypePtr>(nullptr);
} }
@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init()) .def(py::init())
.def(py::pickle( .def(py::pickle(
[](const Bool&) { // __getstate__ [](const Bool &) { // __getstate__
return py::make_tuple(); return py::make_tuple();
}, },
[](const py::tuple&) { // __setstate__ [](const py::tuple &) { // __setstate__
return std::make_shared<Bool>(); return std::make_shared<Bool>();
})); }));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
[](const Int& t) { // __getstate__ [](const Int &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits())); return py::make_tuple(py::int_(t.nbits()));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
[](const UInt& t) { // __getstate__ [](const UInt &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits())); return py::make_tuple(py::int_(t.nbits()));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
[](const Float& t) { // __getstate__ [](const Float &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits())); return py::make_tuple(py::int_(t.nbits()));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init<TypePtr>(), py::arg("element")) .def(py::init<TypePtr>(), py::arg("element"))
.def("element_type", &TensorType::element) .def("element_type", &TensorType::element)
.def(py::pickle( .def(py::pickle(
[](const TensorType& t) { // __getstate__ [](const TensorType &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }

View File

@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>;
class Keyword : public Object { class Keyword : public Object {
public: public:
Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
~Keyword() override = default; ~Keyword() override = default;
MS_DECLARE_PARENT(Keyword, Object) MS_DECLARE_PARENT(Keyword, Object)
@ -70,7 +70,7 @@ class Keyword : public Object {
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
std::string GetKey() const { return key_; } std::string GetKey() const { return key_; }
TypePtr GetValue() const { return value_; } TypePtr GetValue() const { return value_; }
@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>;
class Slice : public Object { class Slice : public Object {
public: public:
Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
: Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
~Slice() override = default; ~Slice() override = default;
@ -95,7 +95,7 @@ class Slice : public Object {
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr get_start() const { return start_; } TypePtr get_start() const { return start_; }
TypePtr get_stop() const { return stop_; } TypePtr get_stop() const { return stop_; }
@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>;
class TensorType : public Object { class TensorType : public Object {
public: public:
TensorType() : Object(kObjectTypeTensorType) {} TensorType() : Object(kObjectTypeTensorType) {}
explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
~TensorType() override = default; ~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object) MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; } TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; } const TypePtr element() const { return element_type_; }
void set_element(const TypePtr& element_type) { element_type_ = element_type; } void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
std::string ToString() const override; std::string ToString() const override;
std::string ToReprString() const override { return "tensor"; } std::string ToReprString() const override { return "tensor"; }
std::string DumpText() const override; std::string DumpText() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
private: private:
TypePtr element_type_; TypePtr element_type_;
@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
class Function : public Object { class Function : public Object {
public: public:
Function(); Function();
Function(const std::vector<TypePtr>& args, const TypePtr retval); Function(const std::vector<TypePtr> &args, const TypePtr retval);
~Function() override = default; ~Function() override = default;
MS_DECLARE_PARENT(Function, Object) MS_DECLARE_PARENT(Function, Object)
@ -141,11 +141,11 @@ class Function : public Object {
// Add temporarily for return abstraction to avoid type checking. // Add temporarily for return abstraction to avoid type checking.
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
const std::vector<TypePtr>& args() const { return args_; } const std::vector<TypePtr> &args() const { return args_; }
const TypePtr& retval() const { return retval_; } const TypePtr &retval() const { return retval_; }
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
std::string ToString() const override; std::string ToString() const override;
std::string ToReprString() const override { return "function"; } std::string ToReprString() const override { return "function"; }
@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>;
class JTagged : public Object { class JTagged : public Object {
public: public:
JTagged() : Object(kObjectTypeJTagged) {} JTagged() : Object(kObjectTypeJTagged) {}
explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
~JTagged() override = default; ~JTagged() override = default;
MS_DECLARE_PARENT(JTagged, Object) MS_DECLARE_PARENT(JTagged, Object)
@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>;
class Problem : public Type { class Problem : public Type {
public: public:
Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
~Problem() override = default; ~Problem() override = default;
MS_DECLARE_PARENT(Problem, Type) MS_DECLARE_PARENT(Problem, Type)
@ -222,7 +222,7 @@ class Problem : public Type {
std::string ToString() const override { return kind_.name(); } std::string ToString() const override { return kind_.name(); }
std::string DumpText() const override { return "ProblemType"; } std::string DumpText() const override { return "ProblemType"; }
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem); friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
private: private:
Named kind_; Named kind_;
@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>;
// helper template // helper template
template <class T> template <class T>
TypePtr Clone(const T& t) { TypePtr Clone(const T &t) {
return t.Clone(); return t.Clone();
} }
TypePtr StringToType(const std::string& type_name); TypePtr StringToType(const std::string &type_name);
// Judge whether x is predicate or is a subclass of predicate. // Judge whether x is predicate or is a subclass of predicate.
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
// Whether t1 is identity or a subclass of t2. // Whether t1 is identity or a subclass of t2.
bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
struct TypeHasher { struct TypeHasher {
std::size_t operator()(TypePtr const& type) const; std::size_t operator()(TypePtr const &type) const;
}; };
struct TypeListHasher { struct TypeListHasher {
std::size_t operator()(const TypePtrList& type_list) const; std::size_t operator()(const TypePtrList &type_list) const;
}; };
struct TypeEqual { struct TypeEqual {
bool operator()(TypePtr const& t1, TypePtr const& t2) const; bool operator()(TypePtr const &t1, TypePtr const &t2) const;
}; };
struct TypeListEqual { struct TypeListEqual {
bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
}; };
extern const TypePtr kTypeExternal; extern const TypePtr kTypeExternal;

View File

@ -24,7 +24,7 @@
#include "pybind_api/export_flags.h" #include "pybind_api/export_flags.h"
namespace mindspore { namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) { static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
std::ostringstream oss; std::ostringstream oss;
bool begin = true; bool begin = true;
int cnt = 0; int cnt = 0;
@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const {
} else { } else {
TypePtrList elements; TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); }); [](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<List>(elements); auto copy = std::make_shared<List>(elements);
return copy; return copy;
} }
@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const {
return elements_[dim]; return elements_[dim];
} }
bool List::operator==(const Type& other) const { bool List::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const List& other_list = static_cast<const List&>(other); const List &other_list = static_cast<const List &>(other);
if (elements_.size() != other_list.elements_.size()) { if (elements_.size() != other_list.elements_.size()) {
return false; return false;
} }
@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const {
return true; return true;
} }
Class::Class(const Named& tag, const ClassAttrVector& attributes, Class::Class(const Named &tag, const ClassAttrVector &attributes,
const std::unordered_map<std::string, ValuePtr>& methods) const std::unordered_map<std::string, ValuePtr> &methods)
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
std::string List::ToString() const { std::string List::ToString() const {
@ -122,7 +122,7 @@ std::string List::DumpText() const {
return buffer.str(); return buffer.str();
} }
bool Class::operator==(const Type& other) const { bool Class::operator==(const Type &other) const {
// Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj.
return &other == this; return &other == this;
} }
@ -143,7 +143,7 @@ std::string Class::ToString() const {
} else { } else {
bool begin = true; bool begin = true;
buffer << "cls." << tag_ << "["; buffer << "cls." << tag_ << "[";
for (auto& attr : attributes_) { for (auto &attr : attributes_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
@ -163,7 +163,7 @@ std::string Class::DumpText() const {
} else { } else {
bool begin = true; bool begin = true;
buffer << "Cls." << tag_ << "["; buffer << "Cls." << tag_ << "[";
for (auto& attr : attributes_) { for (auto &attr : attributes_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const {
} else { } else {
TypePtrList elements; TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); }); [](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<Tuple>(elements); auto copy = std::make_shared<Tuple>(elements);
return copy; return copy;
} }
} }
bool Tuple::operator==(const Type& other) const { bool Tuple::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_tuple = static_cast<const Tuple&>(other); auto other_tuple = static_cast<const Tuple &>(other);
if (elements_.size() != other_tuple.elements_.size()) { if (elements_.size() != other_tuple.elements_.size()) {
return false; return false;
} }
@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const {
std::vector<std::pair<std::string, TypePtr>> kv; std::vector<std::pair<std::string, TypePtr>> kv;
(void)std::transform( (void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv), key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); [](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); });
return std::make_shared<Dictionary>(kv); return std::make_shared<Dictionary>(kv);
} }
} }
@ -259,7 +259,7 @@ std::string Dictionary::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
std::vector<std::string> keys; std::vector<std::string> keys;
std::vector<TypePtr> values; std::vector<TypePtr> values;
for (const auto& kv : key_values_) { for (const auto &kv : key_values_) {
keys.push_back(kv.first); keys.push_back(kv.first);
values.push_back(kv.second); values.push_back(kv.second);
} }
@ -276,12 +276,12 @@ std::string Dictionary::ToString() const {
std::string Dictionary::DumpText() const { return ToString(); } std::string Dictionary::DumpText() const { return ToString(); }
bool Dictionary::operator==(const mindspore::Type& other) const { bool Dictionary::operator==(const mindspore::Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const auto& other_dict = static_cast<const Dictionary&>(other); const auto &other_dict = static_cast<const Dictionary &>(other);
if (key_values_.size() != other_dict.key_values_.size()) { if (key_values_.size() != other_dict.key_values_.size()) {
return false; return false;
} }

View File

@ -40,10 +40,10 @@ namespace mindspore {
class List : public Object { class List : public Object {
public: public:
List() : Object(kObjectTypeList) {} List() : Object(kObjectTypeList) {}
List(const std::initializer_list<TypePtr>& objs) List(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy; // Shadow copy;
explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {}
~List() override {} ~List() override {}
MS_DECLARE_PARENT(List, Object) MS_DECLARE_PARENT(List, Object)
@ -51,7 +51,7 @@ class List : public Object {
TypeId generic_type_id() const override { return kObjectTypeList; } TypeId generic_type_id() const override { return kObjectTypeList; }
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
std::size_t size() const { return elements_.size(); } std::size_t size() const { return elements_.size(); }
TypePtrList elements() const { return elements_; } TypePtrList elements() const { return elements_; }
std::string ToString() const override; std::string ToString() const override;
@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>;
class Class : public Object { class Class : public Object {
public: public:
Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} Class() : Object(kObjectTypeClass), tag_(Named("Class")) {}
Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods); Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods);
~Class() override {} ~Class() override {}
MS_DECLARE_PARENT(Class, Object) MS_DECLARE_PARENT(Class, Object)
TypeId generic_type_id() const override { return kObjectTypeClass; } TypeId generic_type_id() const override { return kObjectTypeClass; }
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; } void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
Named tag() { return tag_; } Named tag() { return tag_; }
std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; } std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; }
std::unordered_map<std::string, ValuePtr> methods() { return methods_; } std::unordered_map<std::string, ValuePtr> methods() { return methods_; }
ClassAttrVector& GetAttributes() { return attributes_; } ClassAttrVector &GetAttributes() { return attributes_; }
ClassAttrVector attributes_; ClassAttrVector attributes_;
@ -99,11 +99,11 @@ class Tuple : public Object {
public: public:
Tuple() : Object(kObjectTypeTuple) {} Tuple() : Object(kObjectTypeTuple) {}
// usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)}; // usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)};
Tuple(const std::initializer_list<TypePtr>& objs) Tuple(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy // Shadow copy
explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
~Tuple() override {} ~Tuple() override {}
MS_DECLARE_PARENT(Tuple, Object) MS_DECLARE_PARENT(Tuple, Object)
@ -115,7 +115,7 @@ class Tuple : public Object {
std::string ToReprString() const override { return "tuple_"; } std::string ToReprString() const override { return "tuple_"; }
std::string DumpText() const override; std::string DumpText() const override;
const TypePtr operator[](size_t dim) const; const TypePtr operator[](size_t dim) const;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtrList elements() const { return elements_; } TypePtrList elements() const { return elements_; }
std::size_t size() const { return elements_.size(); } std::size_t size() const { return elements_.size(); }
@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>;
class Dictionary : public Object { class Dictionary : public Object {
public: public:
Dictionary() : Object(kObjectTypeDictionary) {} Dictionary() : Object(kObjectTypeDictionary) {}
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values) explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values)
: Object(kObjectTypeDictionary, false), key_values_(key_values) {} : Object(kObjectTypeDictionary, false), key_values_(key_values) {}
~Dictionary() override {} ~Dictionary() override {}
@ -136,7 +136,7 @@ class Dictionary : public Object {
TypeId generic_type_id() const override { return kObjectTypeDictionary; } TypeId generic_type_id() const override { return kObjectTypeDictionary; }
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;

View File

@ -24,11 +24,11 @@
#include "pybind_api/export_flags.h" #include "pybind_api/export_flags.h"
namespace mindspore { namespace mindspore {
bool Number::operator==(const Type& other) const { bool Number::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_number = static_cast<const Number&>(other); auto other_number = static_cast<const Number &>(other);
return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_));
} }

View File

@ -49,12 +49,12 @@ class Number : public Object {
TypeId type_id() const override { return number_type_; } TypeId type_id() const override { return number_type_; }
TypeId generic_type_id() const override { return kObjectTypeNumber; } TypeId generic_type_id() const override { return kObjectTypeNumber; }
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override { return std::make_shared<Number>(); } TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
std::string ToString() const override { return "Number"; } std::string ToString() const override { return "Number"; }
std::string ToReprString() const override { return "number"; } std::string ToReprString() const override { return "number"; }
std::string DumpText() const override { return "Number"; } std::string DumpText() const override { return "Number"; }
std::string GetTypeName(const std::string& type_name) const { std::string GetTypeName(const std::string &type_name) const {
std::ostringstream oss; std::ostringstream oss;
oss << type_name; oss << type_name;
if (nbits() != 0) { if (nbits() != 0) {

View File

@ -51,7 +51,7 @@ class RefKeyType : public Object {
class RefType : public Object { class RefType : public Object {
public: public:
RefType() : Object(kObjectTypeRef) {} RefType() : Object(kObjectTypeRef) {}
RefType(const TypePtr& subtype, const TypePtr& subtype_origin) RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
~RefType() override {} ~RefType() override {}
MS_DECLARE_PARENT(RefType, Object) MS_DECLARE_PARENT(RefType, Object)

View File

@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) {
} }
} }
const char* MetaIdLabel(const TypeId& v) { const char *MetaIdLabel(const TypeId &v) {
switch (v) { switch (v) {
case kTypeUnknown: case kTypeUnknown:
return "kTypeUnknown"; return "kTypeUnknown";
@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) {
} }
} }
const char* ObjectIdLabel(const TypeId& v) { const char *ObjectIdLabel(const TypeId &v) {
switch (v) { switch (v) {
case kObjectTypeNumber: case kObjectTypeNumber:
return "kObjectTypeNumber"; return "kObjectTypeNumber";
@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) {
} }
} }
const char* NumberIdLabel(const TypeId& v) { const char *NumberIdLabel(const TypeId &v) {
switch (v) { switch (v) {
case kNumberTypeBool: case kNumberTypeBool:
return "kNumberTypeBool"; return "kNumberTypeBool";
@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) {
} }
} }
const char* TypeIdLabel(const TypeId& v) { const char *TypeIdLabel(const TypeId &v) {
if (v < kMetaTypeEnd) { if (v < kMetaTypeEnd) {
return MetaIdLabel(v); return MetaIdLabel(v);
} else { } else {
@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) {
} }
} }
bool IsSameObjectType(const Type& lhs, const Type& rhs) { bool IsSameObjectType(const Type &lhs, const Type &rhs) {
if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
return false; return false;
} }
return lhs.object_type() == rhs.object_type(); return lhs.object_type() == rhs.object_type();
} }
size_t GetTypeByte(const TypePtr& type_ptr) { size_t GetTypeByte(const TypePtr &type_ptr) {
if (type_ptr && type_ptr->isa<Number>()) { if (type_ptr && type_ptr->isa<Number>()) {
auto number = dyn_cast<Number>(type_ptr); auto number = dyn_cast<Number>(type_ptr);
if (!number) { if (!number) {
@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) {
} }
} }
bool Type::operator==(const Value& other) const { bool Type::operator==(const Value &other) const {
if (other.isa<Type>()) { if (other.isa<Type>()) {
auto other_type = static_cast<const Type*>(&other); auto other_type = static_cast<const Type *>(&other);
return *this == *other_type; return *this == *other_type;
} else { } else {
return false; return false;
@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() {
return ptr; return ptr;
} }
std::ostream& operator<<(std::ostream& os, const Type& type) { std::ostream &operator<<(std::ostream &os, const Type &type) {
os << type.ToString(); os << type.ToString();
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const TypePtr type) { std::ostream &operator<<(std::ostream &os, const TypePtr type) {
os << type->ToString(); os << type->ToString();
return os; return os;
} }
@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const {
return false; return false;
} }
std::ostream& operator<<(std::ostream& os, const Object& obj) { std::ostream &operator<<(std::ostream &os, const Object &obj) {
os << obj.ToString(); os << obj.ToString();
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) { std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
os << obj->ToString(); os << obj->ToString();
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
os << "["; os << "[";
for (size_t i = 0; i < types.size(); ++i) { for (size_t i = 0; i < types.size(); ++i) {
if (i > 0) { if (i > 0) {

View File

@ -95,10 +95,10 @@ enum TypeId : int {
TypeId IntBitsToTypeId(const int nbits); TypeId IntBitsToTypeId(const int nbits);
TypeId UIntBitsToTypeId(const int nbits); TypeId UIntBitsToTypeId(const int nbits);
TypeId FloatBitsToTypeId(const int nbits); TypeId FloatBitsToTypeId(const int nbits);
const char* TypeIdLabel(const TypeId& v); const char *TypeIdLabel(const TypeId &v);
TypeId NormalizeTypeId(const TypeId type_id); TypeId NormalizeTypeId(const TypeId type_id);
bool IsSameObjectType(const Type& lhs, const Type& rhs); bool IsSameObjectType(const Type &lhs, const Type &rhs);
size_t GetTypeByte(const TypePtr& type_ptr); size_t GetTypeByte(const TypePtr &type_ptr);
// Base class for all types // Base class for all types
// forward declaration. // forward declaration.
@ -110,14 +110,14 @@ class Type : public Value {
~Type() override = default; ~Type() override = default;
MS_DECLARE_PARENT(Type, Value) MS_DECLARE_PARENT(Type, Value)
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
TypeId meta_type() const { return meta_type_; } TypeId meta_type() const { return meta_type_; }
virtual TypeId type_id() const { return meta_type_; } virtual TypeId type_id() const { return meta_type_; }
virtual TypeId generic_type_id() const { return kMetaTypeType; } virtual TypeId generic_type_id() const { return kMetaTypeType; }
virtual bool operator!=(const Type& other) const { return !(*this == other); } virtual bool operator!=(const Type &other) const { return !(*this == other); }
virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); }
virtual bool equal(const TypePtr other) const { return *this == *other; } virtual bool equal(const TypePtr other) const { return *this == *other; }
virtual TypeId object_type() const { return kTypeUnknown; } virtual TypeId object_type() const { return kTypeUnknown; }
@ -134,8 +134,8 @@ class Type : public Value {
bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
bool IsGeneric() const { return is_generic_; } bool IsGeneric() const { return is_generic_; }
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
friend std::ostream& operator<<(std::ostream& os, const Type& type); friend std::ostream &operator<<(std::ostream &os, const Type &type);
friend std::ostream& operator<<(std::ostream& os, const TypePtr type); friend std::ostream &operator<<(std::ostream &os, const TypePtr type);
const bool parse_info_ = true; const bool parse_info_ = true;
@ -163,14 +163,14 @@ class Object : public Type {
bool equal(const TypePtr other) const override; bool equal(const TypePtr other) const override;
std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
friend std::ostream& operator<<(std::ostream& os, const Object& obj); friend std::ostream &operator<<(std::ostream &os, const Object &obj);
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj); friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj);
private: private:
const TypeId object_type_; const TypeId object_type_;
}; };
std::ostream& operator<<(std::ostream& os, const TypePtrList& types); std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_

View File

@ -61,7 +61,7 @@ FuncGraph::FuncGraph()
AbstractFunctionPtr FuncGraph::abstract() { AbstractFunctionPtr FuncGraph::abstract() {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
for (auto& p : parameters_) { for (auto &p : parameters_) {
MS_EXCEPTION_IF_NULL(p); MS_EXCEPTION_IF_NULL(p);
if (p->abstract() == nullptr) { if (p->abstract() == nullptr) {
MS_LOG(ERROR) << "Error!!"; MS_LOG(ERROR) << "Error!!";
@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() {
return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract()); return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
} }
abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) { abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) {
AnalysisContextPtr temp_context = context; AnalysisContextPtr temp_context = context;
if (temp_context == nullptr) { if (temp_context == nullptr) {
temp_context = abstract::AnalysisContext::DummyContext(); temp_context = abstract::AnalysisContext::DummyContext();
@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const {
} }
} }
void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) { void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
if (force_new_ret || return_ == nullptr) { if (force_new_ret || return_ == nullptr) {
std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value}); std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() {
return p; return p;
} }
void FuncGraph::add_parameter(const ParameterPtr& p) { void FuncGraph::add_parameter(const ParameterPtr &p) {
if (manager_.lock()) { if (manager_.lock()) {
std::vector<AnfNodePtr> new_params = parameters_; std::vector<AnfNodePtr> new_params = parameters_;
new_params.push_back(p); new_params.push_back(p);
@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) {
} }
} }
ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_graph); ParameterPtr p = std::make_shared<Parameter>(this_graph);
p->set_name(name); p->set_name(name);
@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) {
return p; return p;
} }
bool FuncGraph::has_flag(const std::string& flag) { bool FuncGraph::has_flag(const std::string &flag) {
if (flags_.count(flag)) { if (flags_.count(flag)) {
return flags_[flag]; return flags_[flag];
} }
return false; return false;
} }
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) { CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
order_.push_back(cnode); order_.push_back(cnode);
@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) {
return cnode; return cnode;
} }
CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, const ScopePtr& scope) { CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) {
CNodePtr app = NewCNode(inputs); CNodePtr app = NewCNode(inputs);
app->set_scope(scope); app->set_scope(scope);
return app; return app;
@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, con
void FuncGraph::DumpCNodeList() { void FuncGraph::DumpCNodeList() {
MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
for (const auto& cnode : order_) { for (const auto &cnode : order_) {
MS_LOG(INFO) << cnode->DebugString(); MS_LOG(INFO) << cnode->DebugString();
} }
} }
std::string FuncGraph::ToString() const { std::string FuncGraph::ToString() const {
return mindspore::label_manage::Label(const_cast<FuncGraph*>(this)->shared_from_base<FuncGraph>()->debug_info()); return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
} }
GraphDebugInfoPtr FuncGraph::debug_info() { GraphDebugInfoPtr FuncGraph::debug_info() {
@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
return this->debug_info_; return this->debug_info_;
} }
const AnfNodeSet& FuncGraph::nodes() { const AnfNodeSet &FuncGraph::nodes() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& nodes = mng->nodes(); auto &nodes = mng->nodes();
return nodes[shared_from_base<FuncGraph>()]; return nodes[shared_from_base<FuncGraph>()];
} }
const AnfNodeCounterMap& FuncGraph::value_nodes() { const AnfNodeCounterMap &FuncGraph::value_nodes() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& cts = mng->valuenodes(); auto &cts = mng->valuenodes();
return cts[shared_from_base<FuncGraph>()]; return cts[shared_from_base<FuncGraph>()];
} }
const AnfNodeCounterMap& FuncGraph::free_variables_direct() { const AnfNodeCounterMap &FuncGraph::free_variables_direct() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& fv_direct = mng->free_variables_direct(); auto &fv_direct = mng->free_variables_direct();
return fv_direct[shared_from_base<FuncGraph>()]; return fv_direct[shared_from_base<FuncGraph>()];
} }
const BaseRefCounterMap& FuncGraph::free_variables_total() { const BaseRefCounterMap &FuncGraph::free_variables_total() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& fv_total = mng->free_variables_total(); auto &fv_total = mng->free_variables_total();
return fv_total[shared_from_base<FuncGraph>()]; return fv_total[shared_from_base<FuncGraph>()];
} }
std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() { std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
std::vector<AnfNodePtr> nodes; std::vector<AnfNodePtr> nodes;
const auto& fv_total = this->free_variables_total(); const auto &fv_total = this->free_variables_total();
for (auto& p : fv_total) { for (auto &p : fv_total) {
auto key = p.first; auto key = p.first;
if (utils::isa<AnfNodePtr>(key)) { if (utils::isa<AnfNodePtr>(key)) {
nodes.push_back(utils::cast<AnfNodePtr>(key)); nodes.push_back(utils::cast<AnfNodePtr>(key));
@ -238,8 +238,8 @@ std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
std::vector<FuncGraphPtr> func_graphs; std::vector<FuncGraphPtr> func_graphs;
const auto& fv_total = this->free_variables_total(); const auto &fv_total = this->free_variables_total();
for (auto& p : fv_total) { for (auto &p : fv_total) {
auto key = p.first; auto key = p.first;
if (utils::isa<FuncGraphPtr>(key)) { if (utils::isa<FuncGraphPtr>(key)) {
func_graphs.push_back(utils::cast<FuncGraphPtr>(key)); func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
@ -249,31 +249,31 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs; return func_graphs;
} }
const FuncGraphCounterMap& FuncGraph::func_graphs_used() { const FuncGraphCounterMap &FuncGraph::func_graphs_used() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& used = mng->func_graphs_used(); auto &used = mng->func_graphs_used();
return used[shared_from_base<FuncGraph>()]; return used[shared_from_base<FuncGraph>()];
} }
const FuncGraphSet& FuncGraph::func_graphs_used_total() { const FuncGraphSet &FuncGraph::func_graphs_used_total() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& used = mng->func_graphs_used_total(shared_from_base<FuncGraph>()); auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
return used; return used;
} }
const FuncGraphCounterMap& FuncGraph::func_graph_users() { const FuncGraphCounterMap &FuncGraph::func_graph_users() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& users = mng->func_graph_users(); auto &users = mng->func_graph_users();
return users[shared_from_base<FuncGraph>()]; return users[shared_from_base<FuncGraph>()];
} }
const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() { const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto& users = mng->func_graph_user_cnodes(); auto &users = mng->func_graph_user_cnodes();
return users[shared_from_base<FuncGraph>()]; return users[shared_from_base<FuncGraph>()];
} }
@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() {
return mng->parent(shared_from_base<FuncGraph>()); return mng->parent(shared_from_base<FuncGraph>());
} }
const FuncGraphSet& FuncGraph::children() { const FuncGraphSet &FuncGraph::children() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
return mng->children(shared_from_base<FuncGraph>()); return mng->children(shared_from_base<FuncGraph>());
} }
const FuncGraphSet& FuncGraph::scope() { const FuncGraphSet &FuncGraph::scope() {
auto mng = manager_.lock(); auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
return mng->scopes(shared_from_base<FuncGraph>()); return mng->scopes(shared_from_base<FuncGraph>());
@ -312,9 +312,9 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
return mng->recursive_graphs(shared_from_base<FuncGraph>()); return mng->recursive_graphs(shared_from_base<FuncGraph>());
} }
void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); } void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
auto itr = this->parameter_default_value_.find(name); auto itr = this->parameter_default_value_.find(name);
if (itr == parameter_default_value_.end()) { if (itr == parameter_default_value_.end()) {
return nullptr; return nullptr;
@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) {
} }
// set the default values // set the default values
void FuncGraph::SetDefaultValues(const std::vector<std::string>& name_list, const std::vector<AnfNodePtr>& value_list) { void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
auto all_is_null = std::all_of(value_list.begin(), value_list.end(), auto all_is_null = std::all_of(value_list.begin(), value_list.end(),
[](const AnfNodePtr& node) { return IsValueNode<NullObj>(node); }); [](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); });
if (value_list.empty()) { if (value_list.empty()) {
all_is_null = true; all_is_null = true;
} }
@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
size_t FuncGraph::GetDefaultValueCount() { size_t FuncGraph::GetDefaultValueCount() {
int null_count = int null_count =
std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
[](const std::pair<std::string, AnfNodePtr>& pair) { return IsValueNode<NullObj>(pair.second); }); [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); });
return parameter_default_value_.size() - IntToSize(null_count); return parameter_default_value_.size() - IntToSize(null_count);
} }
@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const {
return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
} }
AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
for (size_t i = 0; i < parameters_.size(); ++i) { for (size_t i = 0; i < parameters_.size(); ++i) {
MS_EXCEPTION_IF_NULL(parameters_[i]); MS_EXCEPTION_IF_NULL(parameters_[i]);
auto param_cast = parameters_[i]->cast<ParameterPtr>(); auto param_cast = parameters_[i]->cast<ParameterPtr>();
@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) {
return nullptr; return nullptr;
} }
void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
std::vector<AnfNodePtr>* specialized_parameter_list, std::vector<AnfNodePtr> *specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, int variable_args_count, std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
int pos_args_input_count) { int pos_args_input_count) {
// if there is variable argument, pass the input arguments that does not match positional args to it as a tuple // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
if (specialized_graph->has_vararg()) { if (specialized_graph->has_vararg()) {
@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph,
} }
} }
void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
std::vector<AnfNodePtr>* specialized_parameter_list, std::vector<AnfNodePtr> *specialized_parameter_list,
const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list, const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) { std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
for (const auto& kwarg : kwarg_list) { for (const auto &kwarg : kwarg_list) {
MS_EXCEPTION_IF_NULL(kwarg); MS_EXCEPTION_IF_NULL(kwarg);
std::string kw_param_name = kwarg->get_key(); std::string kw_param_name = kwarg->get_key();
MS_EXCEPTION_IF_NULL(specialized_graph); MS_EXCEPTION_IF_NULL(specialized_graph);
@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
MS_EXCEPTION_IF_NULL(specialized_parameter_list); MS_EXCEPTION_IF_NULL(specialized_parameter_list);
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
[param_name](const AnfNodePtr& node) { [param_name](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto param = node->cast<ParameterPtr>(); auto param = node->cast<ParameterPtr>();
return param != nullptr && param->name() == param_name; return param != nullptr && param->name() == param_name;
@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes);
} }
void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
const std::vector<AnfNodePtr>& kwarg_keys_tuple_nodes, const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
const std::vector<AnfNodePtr>& kwarg_values_tuple_nodes) { const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) {
if (has_kwarg()) { if (has_kwarg()) {
MS_EXCEPTION_IF_NULL(specialized_graph); MS_EXCEPTION_IF_NULL(specialized_graph);
TraceManager::DebugTrace( TraceManager::DebugTrace(
@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph,
} }
} }
bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list) { bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
// if the function does not have any vararg/kwarg/kwonly/default value/kw args input // if the function does not have any vararg/kwarg/kwonly/default value/kw args input
// return the original graph // return the original graph
if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>&
return true; return true;
} }
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
const std::vector<AnfNodePtr>& specialized_parameter_list, const std::vector<AnfNodePtr> &specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) { std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
MS_EXCEPTION_IF_NULL(specialized_graph); MS_EXCEPTION_IF_NULL(specialized_graph);
for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
auto param_node = specialized_graph->parameters()[i]; auto param_node = specialized_graph->parameters()[i];
@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph,
} }
} }
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list; std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
size_t arguments_count = args_spec_list.size(); size_t arguments_count = args_spec_list.size();
for (const auto& arg : args_spec_list) { for (const auto &arg : args_spec_list) {
// if it is a keyword argument // if it is a keyword argument
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<abstract::AbstractKeywordArg>()) { if (arg->isa<abstract::AbstractKeywordArg>()) {
@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list)
MS_EXCEPTION_IF_NULL(specialized_graph); MS_EXCEPTION_IF_NULL(specialized_graph);
auto params = specialized_graph->parameters(); auto params = specialized_graph->parameters();
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; }); std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false); std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
auto tr = manager->Transact(); auto tr = manager->Transact();
for (auto& node_pair : repl_nodes) { for (auto &node_pair : repl_nodes) {
MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
<< node_pair.second->DebugString(); << node_pair.second->DebugString();
(void)tr.Replace(node_pair.first, node_pair.second); (void)tr.Replace(node_pair.first, node_pair.second);
@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list)
return specialized_graph; return specialized_graph;
} }
void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); } void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
@ -651,7 +651,7 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
std::list<CNodePtr> cnodes; std::list<CNodePtr> cnodes;
auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph);
for (const auto& node : nodes) { for (const auto &node : nodes) {
auto cnode = dyn_cast<CNode>(node); auto cnode = dyn_cast<CNode>(node);
if (cnode) { if (cnode) {
cnodes.push_back(cnode); cnodes.push_back(cnode);
@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() {
} }
} }
void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) { void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) {
if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) { if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
order_.remove(n->cast<CNodePtr>()); order_.remove(n->cast<CNodePtr>());
MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
MS_LOG(DEBUG) << "Check graph " << ToString(); MS_LOG(DEBUG) << "Check graph " << ToString();
for (auto it = order_.begin(); it != order_.end(); (void)it++) { for (auto it = order_.begin(); it != order_.end(); (void)it++) {
for (const auto& input_node : (*it)->inputs()) { for (const auto &input_node : (*it)->inputs()) {
if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) { if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) {
// Need to reorder the wrong order node. // Need to reorder the wrong order node.
auto found = std::find(order_.begin(), it, input_node); auto found = std::find(order_.begin(), it, input_node);
@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() {
} }
auto mng = manager_.lock(); auto mng = manager_.lock();
if (mng != nullptr) { if (mng != nullptr) {
const auto& nodes = mng->nodes()[shared_from_base<FuncGraph>()]; const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()];
if (nodes.size() != (order_.size() + parameters_.size())) { if (nodes.size() != (order_.size() + parameters_.size())) {
DumpCNodeList(); DumpCNodeList();
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() {
const char kPrimHasEffect[] = "_side_effect_flag"; const char kPrimHasEffect[] = "_side_effect_flag";
bool FuncGraph::HasEffect(const CNodePtr& cnode) { bool FuncGraph::HasEffect(const CNodePtr &cnode) {
auto prim = GetCNodePrimitive(cnode); auto prim = GetCNodePrimitive(cnode);
if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) { if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) {
auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>(); auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>();
@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) {
return false; return false;
} }
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& segment) { std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment); std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
for (const auto& node : segment) { for (const auto &node : segment) {
if (roots->size() == 1) { if (roots->size() == 1) {
return roots; return roots;
} }
@ -757,9 +757,9 @@ std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& seg
return roots; return roots;
} }
std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr>& segment) { std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment); std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
for (const auto& node : segment) { for (const auto &node : segment) {
if (nodes->size() == 1) { if (nodes->size() == 1) {
return nodes; return nodes;
} }
@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
std::list<AnfNodePtr> depends_order; std::list<AnfNodePtr> depends_order;
std::vector<CNodePtr> segment; std::vector<CNodePtr> segment;
for (const auto& cnode : order_) { for (const auto &cnode : order_) {
if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) {
continue; continue;
} }
@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
} }
} }
void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr>& depend_inputs) { void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
auto old_ret = output(); auto old_ret = output();
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret}; std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret};
(void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end());

View File

@ -26,29 +26,29 @@
// namespace to support intermediate representation definition // namespace to support intermediate representation definition
namespace mindspore { namespace mindspore {
Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation) bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
: clone_all_valuenodes_(clone_all_valuenodes), : clone_all_valuenodes_(clone_all_valuenodes),
clone_all_child_graphs_(clone_all_child_graphs), clone_all_child_graphs_(clone_all_child_graphs),
clone_all_used_graphs_(clone_all_used_graphs), clone_all_used_graphs_(clone_all_used_graphs),
relation_(relation), relation_(relation),
target_relation_(target_relation == nullptr ? relation : target_relation) { target_relation_(target_relation == nullptr ? relation : target_relation) {
for (auto& func_graph : func_graphs) { for (auto &func_graph : func_graphs) {
AddClone(func_graph); AddClone(func_graph);
} }
scope_ = kDefaultScope; scope_ = kDefaultScope;
type_ = kBasic; type_ = kBasic;
} }
void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList& params, CloneType type) { const AnfNodePtrList &params, CloneType type) {
if (func_graph != nullptr) { if (func_graph != nullptr) {
todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
type_ = type; type_ = type;
} }
} }
void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) { if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
return; return;
@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
} }
} }
void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) { void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target); MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_); TraceManager::DebugTrace(node->debug_info(), relation_);
@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target,
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target); MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_); TraceManager::DebugTrace(node->debug_info(), relation_);
@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
void Cloner::CloneValueNode(const AnfNodePtr& node) { void Cloner::CloneValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
TraceManager::DebugTrace(node->debug_info(), relation_); TraceManager::DebugTrace(node->debug_info(), relation_);
ValueNodePtr new_const = NewValueNode(GetValueNode(node)); ValueNodePtr new_const = NewValueNode(GetValueNode(node));
@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) {
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) { void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target); MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_); TraceManager::DebugTrace(node->debug_info(), relation_);
@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target)
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
if (!clone_all_valuenodes_) { if (!clone_all_valuenodes_) {
return; return;
} }
auto& value_nodes = manager_->valuenodes()[func_graph]; auto &value_nodes = manager_->valuenodes()[func_graph];
for (auto& value_node : value_nodes) { for (auto &value_node : value_nodes) {
auto old_node = value_node.first; auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
if (repl_node_.count(old_node) == 0) { if (repl_node_.count(old_node) == 0) {
@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) {
} }
} }
void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) { void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
if (!clone_all_child_graphs_) { if (!clone_all_child_graphs_) {
return; return;
} }
auto& scopes = manager_->scopes(func_graph); auto &scopes = manager_->scopes(func_graph);
for (auto& graph : scopes) { for (auto &graph : scopes) {
if (graph != func_graph) { if (graph != func_graph) {
todo_.push_back({graph, nullptr, {}}); todo_.push_back({graph, nullptr, {}});
} }
} }
} }
void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) { void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
if (!clone_all_used_graphs_) { if (!clone_all_used_graphs_) {
return; return;
} }
auto& used_graphs = manager_->func_graphs_used()[func_graph]; auto &used_graphs = manager_->func_graphs_used()[func_graph];
for (auto& used_graph : used_graphs) { for (auto &used_graph : used_graphs) {
todo_.push_back({used_graph.first, nullptr, {}}); todo_.push_back({used_graph.first, nullptr, {}});
} }
} }
void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
for (auto& item : func_graph->parameter_default_value()) { for (auto &item : func_graph->parameter_default_value()) {
auto nodes = DeepLinkedGraphSearch(item.second); auto nodes = DeepLinkedGraphSearch(item.second);
for (auto& node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
CloneNode(node, target_func_graph); CloneNode(node, target_func_graph);
@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F
} }
} }
void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func
} }
target_func_graph->set_return(return_node); target_func_graph->set_return(return_node);
auto& value_nodes = manager_->func_graph_valuenodes()[func_graph]; auto &value_nodes = manager_->func_graph_valuenodes()[func_graph];
for (auto& value_node : value_nodes) { for (auto &value_node : value_nodes) {
CloneValueNode(value_node.first, target_func_graph); CloneValueNode(value_node.first, target_func_graph);
} }
} }
void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) { void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto& old_params = func_graph->parameters(); auto &old_params = func_graph->parameters();
if (old_params.size() != params.size()) { if (old_params.size() != params.size()) {
MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
return; return;
@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode
} }
} }
void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) { void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
auto& params = func_graph->parameters(); auto &params = func_graph->parameters();
for (auto& param : params) { for (auto &param : params) {
CloneParameter(param, target_func_graph, true); CloneParameter(param, target_func_graph, true);
} }
repl_func_graph_[func_graph] = target_func_graph; repl_func_graph_[func_graph] = target_func_graph;
} }
void Cloner::GenParameters(const FuncGraphPtr& func_graph) { void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto& free_vars = manager_->free_variables_total(); auto &free_vars = manager_->free_variables_total();
auto iter = free_vars.find(func_graph); auto iter = free_vars.find(func_graph);
if (iter == free_vars.end()) { if (iter == free_vars.end()) {
return; return;
} }
for (auto& fv_map : iter->second) { for (auto &fv_map : iter->second) {
auto& free_var = fv_map.first; auto &free_var = fv_map.first;
if (utils::isa<AnfNodePtr>(free_var)) { if (utils::isa<AnfNodePtr>(free_var)) {
repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var))); repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
} }
} }
} }
void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
param->set_abstract(node->abstract()); param->set_abstract(node->abstract());
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
ParameterPtr old_param = dyn_cast<Parameter>(node); ParameterPtr old_param = dyn_cast<Parameter>(node);
@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) {
} }
} }
ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) { ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(func_graph); ParameterPtr param = std::make_shared<Parameter>(func_graph);
TraceManager::EndTrace(); TraceManager::EndTrace();
@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP
return param; return param;
} }
void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params,
AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) { AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
AnfNodePtrList parameters; AnfNodePtrList parameters;
std::unordered_set<AnfNodePtr> old_params; std::unordered_set<AnfNodePtr> old_params;
for (auto& param : func_graph->parameters()) { for (auto &param : func_graph->parameters()) {
auto iter = repl_node_.find(param); auto iter = repl_node_.find(param);
if (iter != repl_node_.end()) { if (iter != repl_node_.end()) {
(void)old_params.insert(iter->second); (void)old_params.insert(iter->second);
@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList&
} }
} }
AnfNodePtr new_param = nullptr; AnfNodePtr new_param = nullptr;
for (auto& param : params) { for (auto &param : params) {
auto old_param = repl_node_[param]; auto old_param = repl_node_[param];
if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) { if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
repl_node_[old_param] = old_param; repl_node_[old_param] = old_param;
@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList&
func_graph->set_parameters(parameters); func_graph->set_parameters(parameters);
} }
void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList& params) { const AnfNodePtrList &params) {
AnfNodePtr node = nullptr; AnfNodePtr node = nullptr;
auto& repl_func_graph = repl_map_func_graph_[func_graph_user]; auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
auto iter = repl_func_graph.find(func_graph); auto iter = repl_func_graph.find(func_graph);
if (iter == repl_func_graph.end()) { if (iter == repl_func_graph.end()) {
node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr&
OrderParameters(func_graph, inputs); OrderParameters(func_graph, inputs);
} }
void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) { void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
std::unordered_set<AnfNodePtr> old_params; std::unordered_set<AnfNodePtr> old_params;
for (auto& param : func_graph->parameters()) { for (auto &param : func_graph->parameters()) {
(void)old_params.insert(repl_node_[param]); (void)old_params.insert(repl_node_[param]);
} }
std::unordered_set<AnfNodePtr> new_params; std::unordered_set<AnfNodePtr> new_params;
@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis
(void)new_params.insert(new_param); (void)new_params.insert(new_param);
} }
} }
for (auto& param : func_graph->parameters()) { for (auto &param : func_graph->parameters()) {
if (new_params.find(param) == new_params.end()) { if (new_params.find(param) == new_params.end()) {
parameters.push_back(param); parameters.push_back(param);
} }
@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis
func_graph->set_parameters(parameters); func_graph->set_parameters(parameters);
} }
void Cloner::SetEdges(const FuncGraphPtr& func_graph) { void Cloner::SetEdges(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
for (auto& node : func_graph->nodes()) { for (auto &node : func_graph->nodes()) {
if (node == nullptr) { if (node == nullptr) {
continue; continue;
} }
@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
continue; continue;
} }
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
auto& inputs = cnode->inputs(); auto &inputs = cnode->inputs();
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
auto& input = inputs[i]; auto &input = inputs[i];
if (IsValueNode<FuncGraph>(input)) { if (IsValueNode<FuncGraph>(input)) {
auto graph = GetValueNode<FuncGraphPtr>(input); auto graph = GetValueNode<FuncGraphPtr>(input);
auto& repl_func_graph = repl_map_func_graph_[func_graph]; auto &repl_func_graph = repl_map_func_graph_[func_graph];
if (repl_func_graph.find(graph) != repl_func_graph.end()) { if (repl_func_graph.find(graph) != repl_func_graph.end()) {
transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
} }
} else { } else {
auto& repl_node = repl_map_node_[func_graph]; auto &repl_node = repl_map_node_[func_graph];
if (repl_node.find(input) != repl_node.end()) { if (repl_node.find(input) != repl_node.end()) {
transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
} }
@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
} }
} }
void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList& params) { const AnfNodePtrList &params) {
AnfNodePtrList lift_params; AnfNodePtrList lift_params;
AnfNodePtrList input_params; AnfNodePtrList input_params;
AddParameters(func_graph_user, params, &lift_params, &input_params); AddParameters(func_graph_user, params, &lift_params, &input_params);
@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph
if (lift_params.empty()) { if (lift_params.empty()) {
return; return;
} }
for (auto& user : func_graph_user->func_graph_users()) { for (auto &user : func_graph_user->func_graph_users()) {
LiftParameters(user.first, func_graph_user, lift_params); LiftParameters(user.first, func_graph_user, lift_params);
} }
} }
void Cloner::Lift() { void Cloner::Lift() {
for (auto& func_graph_params : repl_func_graph_params_) { for (auto &func_graph_params : repl_func_graph_params_) {
auto& func_graph = func_graph_params.first; auto &func_graph = func_graph_params.first;
auto& params = func_graph_params.second; auto &params = func_graph_params.second;
for (auto& user : func_graph->func_graph_users()) { for (auto &user : func_graph->func_graph_users()) {
LiftParameters(user.first, func_graph, params); LiftParameters(user.first, func_graph, params);
} }
} }
@ -404,18 +404,18 @@ void Cloner::Lift() {
void Cloner::LiftParameters() { void Cloner::LiftParameters() {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
transaction_ = manager_->Transact(); transaction_ = manager_->Transact();
const FuncGraphSet& func_graphs = manager_->func_graphs(); const FuncGraphSet &func_graphs = manager_->func_graphs();
for (auto& func_graph : func_graphs) { for (auto &func_graph : func_graphs) {
GenParameters(func_graph); GenParameters(func_graph);
} }
Lift(); Lift();
for (auto& func_graph : func_graphs) { for (auto &func_graph : func_graphs) {
SetEdges(func_graph); SetEdges(func_graph);
} }
transaction_.Commit(); transaction_.Commit();
} }
bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
// Make sure only inline once // Make sure only inline once
if (status_.count(func_graph) != 0) { if (status_.count(func_graph) != 0) {
@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) {
return true; return true;
} }
void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet& nodes = manager_->nodes()[func_graph]; const AnfNodeSet &nodes = manager_->nodes()[func_graph];
for (auto& node : nodes) { for (auto &node : nodes) {
CloneNode(node, target_func_graph); CloneNode(node, target_func_graph);
} }
} }
@ -449,7 +449,7 @@ void Cloner::Run() {
// Basic and Inline Clone // Basic and Inline Clone
FuncGraphPtrList func_graphs; FuncGraphPtrList func_graphs;
(void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
[](const CloneInfo& item) -> FuncGraphPtr { return item.origin; }); [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
manager_ = Manage(func_graphs, false); manager_ = Manage(func_graphs, false);
CloneNodes(); CloneNodes();
LinkEdges(); LinkEdges();
@ -495,13 +495,13 @@ void Cloner::CloneNodes() {
} }
void Cloner::LinkEdges() { void Cloner::LinkEdges() {
for (auto& node_pair : nodes_) { for (auto &node_pair : nodes_) {
CNodePtr old_node = node_pair.first; CNodePtr old_node = node_pair.first;
CNodePtr new_node = node_pair.second; CNodePtr new_node = node_pair.second;
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);
for (auto& input : old_node->inputs()) { for (auto &input : old_node->inputs()) {
auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
new_node->add_input(new_input); new_node->add_input(new_input);
} }
} }
@ -509,10 +509,10 @@ void Cloner::LinkEdges() {
// For the graphs cloned, update its default value map to the cloned nodes // For the graphs cloned, update its default value map to the cloned nodes
void Cloner::SetDefaults() { void Cloner::SetDefaults() {
for (auto& item : graph_set_) { for (auto &item : graph_set_) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
if (repl_func_graph_.count(item) != 0) { if (repl_func_graph_.count(item) != 0) {
for (auto& param_def : item->parameter_default_value()) { for (auto &param_def : item->parameter_default_value()) {
MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
if (repl_node_.count(param_def.second) != 0) { if (repl_node_.count(param_def.second) != 0) {
repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
@ -524,7 +524,7 @@ void Cloner::SetDefaults() {
} }
} }
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(root);
if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) {
MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
} }
AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
double time = GetTime(); double time = GetTime();
#endif #endif
@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) {
return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
} }
FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
double time = GetTime(); double time = GetTime();
#endif #endif
@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) {
return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
} }
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) { FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr); Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
return cloner[func_graph]; return cloner[func_graph];
} }
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList& func_graph_args, const ScopePtr& scope) { const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
Cloner cloner({}, false); Cloner cloner({}, false);
@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe
return cloner[func_graph->output()]; return cloner[func_graph->output()];
} }
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) { FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
Cloner cloner({}, false); Cloner cloner({}, false);
cloner.AddClone(func_graph, nullptr, {}, kLifting); cloner.AddClone(func_graph, nullptr, {}, kLifting);
return cloner[func_graph]; return cloner[func_graph];
} }
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphPtrList func_graphs = {func_graph}; FuncGraphPtrList func_graphs = {func_graph};
ClonerPtr cloner = ClonerPtr cloner =
@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r
return cloner; return cloner;
} }
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), relation); TraceManager::DebugTrace(func_graph->debug_info(), relation);
auto new_func_graph = std::make_shared<FuncGraph>(); auto new_func_graph = std::make_shared<FuncGraph>();
TraceManager::EndTrace(); TraceManager::EndTrace();
auto& parameters = func_graph->parameters(); auto &parameters = func_graph->parameters();
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void { (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
(void)new_func_graph->add_parameter(); (void)new_func_graph->add_parameter();
@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
new_func_graph->set_is_generate(func_graph->is_generated()); new_func_graph->set_is_generate(func_graph->is_generated());
for (auto& item : func_graph->parameter_default_value()) { for (auto &item : func_graph->parameter_default_value()) {
new_func_graph->set_param_default_value(item.first, cloner[item.second]); new_func_graph->set_param_default_value(item.first, cloner[item.second]);
} }

View File

@ -43,26 +43,26 @@ struct CloneInfo {
class Cloner { class Cloner {
public: public:
explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false,
bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
const TraceInfoPtr& relation = std::make_shared<TraceCopy>(), const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
const TraceInfoPtr& target_relation = nullptr); const TraceInfoPtr &target_relation = nullptr);
~Cloner() = default; ~Cloner() = default;
void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
const AnfNodePtrList& params = {}, CloneType type = kBasic); const AnfNodePtrList &params = {}, CloneType type = kBasic);
void Run(); void Run();
// Interfaces for specializer // Interfaces for specializer
AnfNodePtr CloneDisconnected(const AnfNodePtr& root); AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
AnfNodePtr operator[](const AnfNodePtr& node); AnfNodePtr operator[](const AnfNodePtr &node);
FuncGraphPtr operator[](const FuncGraphPtr& func_graph); FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
// Map of replicate nodes and graphs // Map of replicate nodes and graphs
std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; } std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; }
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; } std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; }
// Scope of cloned graphs // Scope of cloned graphs
void set_scope(const ScopePtr& scope) { scope_ = scope; } void set_scope(const ScopePtr &scope) { scope_ = scope; }
const ScopePtr scope() const { return scope_; } const ScopePtr scope() const { return scope_; }
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_; std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
@ -71,31 +71,31 @@ class Cloner {
void CloneNodes(); void CloneNodes();
void LinkEdges(); void LinkEdges();
void SetDefaults(); void SetDefaults();
void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneValueNode(const AnfNodePtr& node); void CloneValueNode(const AnfNodePtr &node);
void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr& func_graph); void CloneValueNodes(const FuncGraphPtr &func_graph);
void AddChildGraphs(const FuncGraphPtr& func_graph); void AddChildGraphs(const FuncGraphPtr &func_graph);
void AddTotalGraphs(const FuncGraphPtr& func_graph); void AddTotalGraphs(const FuncGraphPtr &func_graph);
bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph);
void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void GenParameters(const FuncGraphPtr& func_graph); void GenParameters(const FuncGraphPtr &func_graph);
void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node);
ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params, AnfNodePtrList *const lift_params,
AnfNodePtrList* const input_params); AnfNodePtrList *const input_params);
void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs);
void SetEdges(const FuncGraphPtr& func_graph); void SetEdges(const FuncGraphPtr &func_graph);
void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList& params); const AnfNodePtrList &params);
void Lift(); void Lift();
void LiftParameters(); void LiftParameters();
@ -118,17 +118,17 @@ class Cloner {
std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_; std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
}; };
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph);
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr);
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph);
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
const TraceInfoPtr& relation = std::make_shared<TraceTransform>()); const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_

View File

@ -27,17 +27,17 @@
namespace mindspore { namespace mindspore {
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs, bool manage) { FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
auto m = std::make_shared<FuncGraphManager>(func_graphs, manage); auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
m->Init(); m->Init();
return m; return m;
} }
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage) { FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
FuncGraphManagerPtr m = nullptr; FuncGraphManagerPtr m = nullptr;
bool root = false; bool root = false;
for (auto& fg : func_graphs) { for (auto &fg : func_graphs) {
if (fg == nullptr) { if (fg == nullptr) {
continue; continue;
} }
@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool ma
root = true; root = true;
} }
for (auto& fg : func_graphs) { for (auto &fg : func_graphs) {
if (fg == nullptr) { if (fg == nullptr) {
continue; continue;
} }
@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
return Manage(func_graphs, manage); return Manage(func_graphs, manage);
} }
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage) FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
: roots_(roots), is_manage_(manage) { : roots_(roots), is_manage_(manage) {
Reset(); Reset();
} }
@ -103,12 +103,12 @@ void FuncGraphManager::Init() {
auto roots = roots_; auto roots = roots_;
roots_ = FuncGraphSet(); roots_ = FuncGraphSet();
for (auto& fg : roots) { for (auto &fg : roots) {
AddFuncGraph(fg, true); AddFuncGraph(fg, true);
} }
} }
FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const { FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
func_graph_parents_total_->Recompute(fg); func_graph_parents_total_->Recompute(fg);
@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg)
return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
} }
FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(func_graph_parent_); MS_EXCEPTION_IF_NULL(func_graph_parent_);
MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const {
return func_graph_parent_->parent_analysis()[fg]; return func_graph_parent_->parent_analysis()[fg];
} }
FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(children_); MS_EXCEPTION_IF_NULL(children_);
MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const {
return children_->children_analysis()[fg]; return children_->children_analysis()[fg];
} }
FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(scopes_); MS_EXCEPTION_IF_NULL(scopes_);
MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const {
return scopes_->scope_analysis()[fg]; return scopes_->scope_analysis()[fg];
} }
FVTotalMap& FuncGraphManager::free_variables_total() const { FVTotalMap &FuncGraphManager::free_variables_total() const {
MS_EXCEPTION_IF_NULL(free_variables_total_); MS_EXCEPTION_IF_NULL(free_variables_total_);
free_variables_total_->Recompute(); free_variables_total_->Recompute();
return free_variables_total_->fv_total_analysis(); return free_variables_total_->fv_total_analysis();
} }
FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const { FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(func_graphs_used_total_); MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
func_graphs_used_total_->Recompute(fg); func_graphs_used_total_->Recompute(fg);
return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
} }
bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
recursive_->Recompute(fg); recursive_->Recompute(fg);
if (recursive_->recursive_analysis().count(fg) == 0) { if (recursive_->recursive_analysis().count(fg) == 0) {
@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const {
return recursive_->recursive_analysis()[fg]; return recursive_->recursive_analysis()[fg];
} }
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const { std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
if (recursive(fg)) { if (recursive(fg)) {
if (!recursive_->recursive_map().count(fg)) { if (!recursive_->recursive_map().count(fg)) {
@ -185,7 +185,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
} }
} }
bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const { bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(j_total_); MS_EXCEPTION_IF_NULL(j_total_);
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
j_total_->Recompute(fg); j_total_->Recompute(fg);
@ -225,10 +225,10 @@ void FuncGraphManager::Clear() {
signals_->InvalidateComputer(); signals_->InvalidateComputer();
} }
void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
MS_LOG(DEBUG) << "Start keep roots"; MS_LOG(DEBUG) << "Start keep roots";
bool root_exist = false; bool root_exist = false;
for (auto& item : func_graphs) { for (auto &item : func_graphs) {
if (roots_.contains(item)) { if (roots_.contains(item)) {
root_exist = true; root_exist = true;
break; break;
@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
roots = roots_; roots = roots_;
} else { } else {
roots_.clear(); roots_.clear();
for (auto& item : roots) { for (auto &item : roots) {
AddFuncGraph(item, true); AddFuncGraph(item, true);
} }
} }
FuncGraphSet keep; FuncGraphSet keep;
for (auto& item : roots) { for (auto &item : roots) {
MS_LOG(DEBUG) << "roots: " << item->ToString(); MS_LOG(DEBUG) << "roots: " << item->ToString();
keep.update(func_graphs_used_total(item)); keep.update(func_graphs_used_total(item));
#ifdef DEBUG #ifdef DEBUG
for (auto& k : keep) { for (auto &k : keep) {
MS_LOG(DEBUG) << "keep: " << k->ToString(); MS_LOG(DEBUG) << "keep: " << k->ToString();
} }
#endif #endif
@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
} else { } else {
Clear(); Clear();
FuncGraphSet roots(func_graphs); FuncGraphSet roots(func_graphs);
for (auto& item : roots) { for (auto &item : roots) {
AddFuncGraph(item, true); AddFuncGraph(item, true);
} }
} }
@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() {
MaybeDropFuncGraphs(func_graphs_, true); MaybeDropFuncGraphs(func_graphs_, true);
} }
void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
if (is_manage_) { if (is_manage_) {
if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { if (fg->manager() != nullptr && (&(*fg->manager()) != this)) {
@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) {
func_graphs_.add(fg); func_graphs_.add(fg);
} }
void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) { void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
FuncGraphSet todo(func_graphs); FuncGraphSet todo(func_graphs);
std::set<FuncGraphPtr> dropped; std::set<FuncGraphPtr> dropped;
// int count = 0; // int count = 0;
@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool
continue; continue;
} }
MS_EXCEPTION_IF_NULL(func_graph_users_); MS_EXCEPTION_IF_NULL(func_graph_users_);
auto& users = func_graph_users_->count_func_graphs_map()[func_graph]; auto &users = func_graph_users_->count_func_graphs_map()[func_graph];
if (!users.empty() && !ignore_users) { if (!users.empty() && !ignore_users) {
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
continue; continue;
@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool
todo.update(MaybeDropNodes(return_vec)); todo.update(MaybeDropNodes(return_vec));
} }
MS_EXCEPTION_IF_NULL(signals_); MS_EXCEPTION_IF_NULL(signals_);
for (auto& fg : dropped) { for (auto &fg : dropped) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
signals_->DropFuncGraph(fg); signals_->DropFuncGraph(fg);
all_nodes_.difference_update(fg->parameters()); all_nodes_.difference_update(fg->parameters());
@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
MS_EXCEPTION_IF_NULL(inp); MS_EXCEPTION_IF_NULL(inp);
if (direction == kDecEdge) { if (direction == kDecEdge) {
MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString();
auto& users_node = node_users_[inp]; auto &users_node = node_users_[inp];
if (!users_node.contains(make_pair(node, index))) { if (!users_node.contains(make_pair(node, index))) {
return; return;
} }
@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
AddFuncGraph(GetValueNode<FuncGraphPtr>(inp)); AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
} }
auto& users_node = node_users_[inp]; auto &users_node = node_users_[inp];
users_node.add(make_pair(node, index)); users_node.add(make_pair(node, index));
MS_EXCEPTION_IF_NULL(signals_); MS_EXCEPTION_IF_NULL(signals_);
signals_->AddEdge(node, index, inp); signals_->AddEdge(node, index, inp);
} }
} }
void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) { void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
int index = 0; int index = 0;
for (auto& inp : cnode->inputs()) { for (auto &inp : cnode->inputs()) {
ProcessEdge(cnode, index, inp, direction); ProcessEdge(cnode, index, inp, direction);
++index; ++index;
} }
} }
} }
IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) {
if (all_nodes_.contains(node)) { if (all_nodes_.contains(node)) {
return EXCLUDE; return EXCLUDE;
} else { } else {
@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) {
} }
} }
void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet acq; AnfNodeSet acq;
for (auto& node : nodes) { for (auto &node : nodes) {
std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1);
AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit));
@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
acq.update(new_nodes); acq.update(new_nodes);
} }
for (auto& node : acq) { for (auto &node : acq) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph(); FuncGraphPtr fg = node->func_graph();
if (fg != nullptr) { if (fg != nullptr) {
@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
} }
} }
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& nodes) { FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet nodes_ordered(nodes); AnfNodeSet nodes_ordered(nodes);
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
MS_EXCEPTION_IF_NULL(signals_); MS_EXCEPTION_IF_NULL(signals_);
@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>&
if (!all_nodes_.contains(node)) { if (!all_nodes_.contains(node)) {
continue; continue;
} }
AnfNodeIndexSet& users = node_users_[node]; AnfNodeIndexSet &users = node_users_[node];
std::vector<AnfNodePtr> parameters; std::vector<AnfNodePtr> parameters;
if (!users.empty() || if (!users.empty() ||
@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>&
return func_graphs_to_check; return func_graphs_to_check;
} }
void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters) { void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters) {
auto tr = Transact(); auto tr = Transact();
tr.SetParameters(fg, parameters); tr.SetParameters(fg, parameters);
tr.Commit(); tr.Commit();
} }
bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
auto tr = Transact(); auto tr = Transact();
bool success = tr.Replace(old_node, new_node); bool success = tr.Replace(old_node, new_node);
if (success) { if (success) {
@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new
return success; return success;
} }
void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) { void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
auto tr = Transact(); auto tr = Transact();
tr.SetEdge(node, index, value); tr.SetEdge(node, index, value);
tr.Commit(); tr.Commit();
} }
void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) { void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) {
AnfNodePtr source_return = source->get_return(); AnfNodePtr source_return = source->get_return();
AnfNodePtr source_output = source->output(); AnfNodePtr source_output = source->output();
AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0); AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
(void)all_nodes_.erase(source_return); (void)all_nodes_.erase(source_return);
(void)node_users_.erase(source_return); (void)node_users_.erase(source_return);
signals_->DropNode(source_return); signals_->DropNode(source_return);
for (auto& node : source->nodes()) { for (auto &node : source->nodes()) {
node->set_func_graph(target); node->set_func_graph(target);
if (node->scope() == kDefaultScope) { if (node->scope() == kDefaultScope) {
node->set_scope(scope); node->set_scope(scope);
} }
} }
for (auto& used : source->func_graphs_used()) { for (auto &used : source->func_graphs_used()) {
(void)func_graph_users_->Inc(used.first, target, used.second); (void)func_graph_users_->Inc(used.first, target, used.second);
(void)this->func_graph_users()[used.first].erase(source); (void)this->func_graph_users()[used.first].erase(source);
} }
for (auto& child : this->func_graph_child_direct()[source]) { for (auto &child : this->func_graph_child_direct()[source]) {
(void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)func_graph_parents_direct_->Inc(child.first, target, child.second);
(void)this->func_graph_parents_direct()[child.first].erase(source); (void)this->func_graph_parents_direct()[child.first].erase(source);
} }
for (auto& fv_count : this->free_variables_direct()[source]) { for (auto &fv_count : this->free_variables_direct()[source]) {
auto fv_g = fv_count.first->func_graph(); auto fv_g = fv_count.first->func_graph();
auto& count_on_g = this->func_graph_child_direct()[fv_g]; auto &count_on_g = this->func_graph_child_direct()[fv_g];
auto pair = count_on_g.find(source); auto pair = count_on_g.find(source);
if (fv_g != target && pair != count_on_g.end()) { if (fv_g != target && pair != count_on_g.end()) {
(void)func_graph_child_direct_->Inc(fv_g, target, pair->second); (void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() {
return tr; return tr;
} }
void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges,
EdgeTupleCounter* rm_edges, Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms) { EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) {
for (auto& iter : changes) { for (auto &iter : changes) {
auto operation = iter.op; auto operation = iter.op;
auto args = iter.args; auto args = iter.args;
if (operation == Change::kTxSetEdge) { if (operation == Change::kTxSetEdge) {
@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl
auto param = args.cast<ArgsOfSetParams>(); auto param = args.cast<ArgsOfSetParams>();
MS_EXCEPTION_IF_NULL(param.func_graph); MS_EXCEPTION_IF_NULL(param.func_graph);
auto old_parameters = param.func_graph->parameters(); auto old_parameters = param.func_graph->parameters();
for (auto& p : param.params) { for (auto &p : param.params) {
(*adds)[p] += 1; (*adds)[p] += 1;
} }
for (auto& p : old_parameters) { for (auto &p : old_parameters) {
(*rms)[p] += 1; (*rms)[p] += 1;
} }
param.func_graph->set_parameters(param.params); param.func_graph->set_parameters(param.params);
@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl
} }
} }
void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) {
EdgeTupleCounter add_edges; EdgeTupleCounter add_edges;
EdgeTupleCounter rm_edges; EdgeTupleCounter rm_edges;
Counter<AnfNodePtr> adds; Counter<AnfNodePtr> adds;
@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms);
auto sub_edges = add_edges - rm_edges; auto sub_edges = add_edges - rm_edges;
for (auto& iter : sub_edges) { for (auto &iter : sub_edges) {
auto root_node = iter.first.first; auto root_node = iter.first.first;
int index = iter.first.second.first; int index = iter.first.second.first;
auto new_node = iter.first.second.second; auto new_node = iter.first.second.second;
@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
auto sub_nodes = adds - rms; auto sub_nodes = adds - rms;
std::vector<AnfNodePtr> nodes; std::vector<AnfNodePtr> nodes;
(void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes),
[](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; }); [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
AcquireNodes(nodes); AcquireNodes(nodes);
auto sub_edges_reverse = rm_edges - add_edges; auto sub_edges_reverse = rm_edges - add_edges;
for (auto& iter : sub_edges_reverse) { for (auto &iter : sub_edges_reverse) {
auto root_node = iter.first.first; auto root_node = iter.first.first;
int index = iter.first.second.first; int index = iter.first.second.first;
auto old_node = iter.first.second.second; auto old_node = iter.first.second.second;
@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
std::vector<AnfNodePtr> nodes_reverse; std::vector<AnfNodePtr> nodes_reverse;
(void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse),
[](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; }); [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
auto drop_func_graphs = MaybeDropNodes(nodes_reverse); auto drop_func_graphs = MaybeDropNodes(nodes_reverse);
MaybeDropFuncGraphs(*drop_func_graphs); MaybeDropFuncGraphs(*drop_func_graphs);
} }
void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params) { void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params) {
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
} }
bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);
FuncGraphPtr old_func_graph = old_node->func_graph(); FuncGraphPtr old_func_graph = old_node->func_graph();
@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr&
return false; return false;
} }
auto users = manager_->node_users()[old_node]; auto users = manager_->node_users()[old_node];
for (auto& node : users) { for (auto &node : users) {
SetEdge(node.first, node.second, new_node); SetEdge(node.first, node.second, new_node);
} }
return true; return true;
} }
void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) { void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
if (k < 0) { if (k < 0) {
MS_LOG(EXCEPTION) << "Invalid value k = " << k; MS_LOG(EXCEPTION) << "Invalid value k = " << k;
} }
@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() {
manager_->CommitChanges(changes); manager_->CommitChanges(changes);
} }
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
: manager_(manager), include_func_graph_none_(false) { : manager_(manager), include_func_graph_none_(false) {
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager)
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
} }
NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() { NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
include_func_graph_none_ = true; include_func_graph_none_ = true;
nodes_analysis_[nullptr] = AnfNodeSet(); nodes_analysis_[nullptr] = AnfNodeSet();
@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) {
void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// change the owner of node except for the src's return node // change the owner of node except for the src's return node
for (auto& it : nodes_analysis_[src]) { for (auto &it : nodes_analysis_[src]) {
nodes_analysis_[dst].add(it); nodes_analysis_[dst].add(it);
} }
(void)nodes_analysis_.erase(src); (void)nodes_analysis_.erase(src);
@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
} }
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
auto& d = count_nodes_map_[func_graph]; auto &d = count_nodes_map_[func_graph];
if (d.count(key) == 0) { if (d.count(key) == 0) {
d[key] = count; d[key] = count;
return true; return true;
@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP
return false; return false;
} }
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto& d = count_nodes_map_[func_graph]; auto &d = count_nodes_map_[func_graph];
if (d.count(key) != 0) { if (d.count(key) != 0) {
if (d[key] == count) { if (d[key] == count) {
(void)d.erase(key); (void)d.erase(key);
@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP
return false; return false;
} }
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) { bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) {
if (count > 0) { if (count > 0) {
return Inc(func_graph, key, count); return Inc(func_graph, key, count);
} else if (count < 0) { } else if (count < 0) {
@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP
} }
} }
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto& d = count_func_graphs_map_[func_graph]; auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) { if (d.count(key) == 0) {
d[key] = count; d[key] = count;
return true; return true;
@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr
return false; return false;
} }
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto& d = count_func_graphs_map_[func_graph]; auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) { if (d.count(key) != 0) {
if (d[key] == count) { if (d[key] == count) {
(void)d.erase(key); (void)d.erase(key);
@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr
return false; return false;
} }
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) { bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) { if (count > 0) {
return Inc(func_graph, key, count); return Inc(func_graph, key, count);
} else if (count < 0) { } else if (count < 0) {
@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr
} }
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) { for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
} }
(void)count_nodes_map_.erase(src); (void)count_nodes_map_.erase(src);
@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed
} }
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) { for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
} }
(void)count_nodes_map_.erase(src); (void)count_nodes_map_.erase(src);
@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc
} }
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) { for (auto &it : count_nodes_map_[src]) {
FuncGraphPtr fg2 = it.first->func_graph(); FuncGraphPtr fg2 = it.first->func_graph();
if (fg2 != dst) { if (fg2 != dst) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
(void)count_nodes_map_.erase(src); (void)count_nodes_map_.erase(src);
} }
static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) { static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
FuncGraphPtr gn = std::make_shared<FuncGraph>(); FuncGraphPtr gn = std::make_shared<FuncGraph>();
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
return gn; return gn;
@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP
} }
void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_func_graphs_map_[src]) { for (auto &it : count_func_graphs_map_[src]) {
FuncGraphPtr fg = it.first; FuncGraphPtr fg = it.first;
if (fg != dst) { if (fg != dst) {
(void)Inc(dst, fg, it.second); (void)Inc(dst, fg, it.second);
@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr
} }
void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_func_graphs_map_[src]) { for (auto &it : count_func_graphs_map_[src]) {
if (it.first != dst) { if (it.first != dst) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
} }
@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed
void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use // all graph use in src need to change to dst, so meger the to dst use
for (auto& it : count_func_graphs_map_[src]) { for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
} }
(void)count_func_graphs_map_[dst].erase(src); (void)count_func_graphs_map_[dst].erase(src);
@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp
} }
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) { for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
} }
(void)count_nodes_map_.erase(src); (void)count_nodes_map_.erase(src);
@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp,
void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use // all graph use in src need to change to dst, so meger the to dst use
for (auto& it : count_func_graphs_map_[src]) { for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second); (void)Inc(dst, it.first, it.second);
} }
(void)count_func_graphs_map_.erase(src); (void)count_func_graphs_map_.erase(src);
} }
DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
validate_ = false; validate_ = false;
@ -914,20 +914,20 @@ void DepComputer::Recompute() {
} }
} }
void DepComputer::Recompute(const FuncGraphPtr& fg) { void DepComputer::Recompute(const FuncGraphPtr &fg) {
if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
RealRecompute(fg); RealRecompute(fg);
func_graphs_validate_[fg] = true; func_graphs_validate_[fg] = true;
} }
} }
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
if (path == nullptr || path->contains(fg)) { if (path == nullptr || path->contains(fg)) {
return std::make_shared<FuncGraphSet>(); return std::make_shared<FuncGraphSet>();
} }
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_; FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_;
for (auto& dep : deps[fg]) { for (auto &dep : deps[fg]) {
MS_EXCEPTION_IF_NULL(dep.first); MS_EXCEPTION_IF_NULL(dep.first);
auto proxy = dep.first->transforms().find("proxy"); auto proxy = dep.first->transforms().find("proxy");
if (proxy != dep.first->transforms().end()) { if (proxy != dep.first->transforms().end()) {
@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
} }
bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) { bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
auto l1 = lhs.second.size(); auto l1 = lhs.second.size();
auto l2 = rhs.second.size(); auto l2 = rhs.second.size();
return l1 < l2; return l1 < l2;
@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
} else { } else {
// return nearest parent as parent // return nearest parent as parent
FuncGraphSet deps_copy(deps); FuncGraphSet deps_copy(deps);
for (auto& dep : deps) { for (auto &dep : deps) {
auto parent_deps = this->manager_->func_graph_parents_total(dep); auto parent_deps = this->manager_->func_graph_parents_total(dep);
for (auto& p_d : parent_deps) { for (auto &p_d : parent_deps) {
if (deps_copy.count(p_d)) { if (deps_copy.count(p_d)) {
(void)deps_copy.erase(p_d); (void)deps_copy.erase(p_d);
} }
@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
auto used_fg_total = manager_->func_graphs_used_total(fg); auto used_fg_total = manager_->func_graphs_used_total(fg);
for (auto& used_fg : used_fg_total) { for (auto &used_fg : used_fg_total) {
if (manager_->parent(used_fg) == fg) { if (manager_->parent(used_fg) == fg) {
children_analysis_[fg].add(used_fg); children_analysis_[fg].add(used_fg);
} }
@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
void ScopeComputer::RealRecompute(FuncGraphPtr fg) { void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
auto& children = manager_->children(fg); auto &children = manager_->children(fg);
scope_analysis_[fg] = FuncGraphSet(); scope_analysis_[fg] = FuncGraphSet();
scope_analysis_[fg].add(fg); scope_analysis_[fg].add(fg);
for (auto& child : children) { for (auto &child : children) {
scope_analysis_[fg].add(child); scope_analysis_[fg].add(child);
} }
} }
@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() {
auto manager = DepComputer::manager_; auto manager = DepComputer::manager_;
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
for (auto& fg : manager->func_graphs()) { for (auto &fg : manager->func_graphs()) {
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>(); fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
} }
for (auto& fg : manager->func_graphs()) { for (auto &fg : manager->func_graphs()) {
AnfNodeCounterMap items = manager->free_variables_direct()[fg]; AnfNodeCounterMap items = manager->free_variables_direct()[fg];
for (auto& iter : items) { for (auto &iter : items) {
auto curr = fg; auto curr = fg;
while (curr) { while (curr) {
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
curr = manager->parent(curr); curr = manager->parent(curr);
const AnfNodeSet& nodes = manager->nodes()[curr]; const AnfNodeSet &nodes = manager->nodes()[curr];
if (nodes.contains(iter.first)) { if (nodes.contains(iter.first)) {
break; break;
} }
@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() {
} }
auto items_fg = manager->func_graphs_used()[fg]; auto items_fg = manager->func_graphs_used()[fg];
for (auto& iter : items_fg) { for (auto &iter : items_fg) {
auto p = manager->parent(iter.first); auto p = manager->parent(iter.first);
if (p == nullptr) { if (p == nullptr) {
continue; continue;
@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() {
} }
} }
} }
for (auto& fg : manager->func_graphs()) { for (auto &fg : manager->func_graphs()) {
auto& fvp = count_nodes_map_[fg]; auto &fvp = count_nodes_map_[fg];
auto& fvg = count_func_graphs_map_[fg]; auto &fvg = count_func_graphs_map_[fg];
for (auto& item : fvp) { for (auto &item : fvp) {
fv_total_analysis_[fg][item.first] = item.second; fv_total_analysis_[fg][item.first] = item.second;
} }
for (auto& item : fvg) { for (auto &item : fvg) {
fv_total_analysis_[fg][item.first] = item.second; fv_total_analysis_[fg][item.first] = item.second;
} }
} }
@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() {
void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
auto& used = this->manager_->func_graphs_used(); auto &used = this->manager_->func_graphs_used();
std::vector<FuncGraphPtr> todo; std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new; std::vector<FuncGraphPtr> todo_new;
todo.push_back(fg); todo.push_back(fg);
while (!todo.empty()) { while (!todo.empty()) {
todo_new.clear(); todo_new.clear();
for (auto& gt : todo) { for (auto &gt : todo) {
for (auto& item : used[gt]) { for (auto &item : used[gt]) {
auto used_fg = item.first; auto used_fg = item.first;
if (used_fg == fg) { if (used_fg == fg) {
func_graph_used_total_analysis_[fg].add(used_fg); func_graph_used_total_analysis_[fg].add(used_fg);
@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
} }
} }
bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) { bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto& used = manager->func_graphs_used(); auto &used = manager->func_graphs_used();
std::vector<FuncGraphPtr> todo; std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new; std::vector<FuncGraphPtr> todo_new;
todo.push_back(fg); todo.push_back(fg);
FuncGraphSet used_total; FuncGraphSet used_total;
while (!todo.empty()) { while (!todo.empty()) {
todo_new.clear(); todo_new.clear();
for (auto& gt : todo) { for (auto &gt : todo) {
for (auto& item : used[gt]) { for (auto &item : used[gt]) {
auto used_g = item.first; auto used_g = item.first;
if (used_g == fg) { if (used_g == fg) {
return true; return true;
@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
} }
void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace) { void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
MS_EXCEPTION_IF_NULL(trace); MS_EXCEPTION_IF_NULL(trace);
auto res = std::find(trace->begin(), trace->end(), fg); auto res = std::find(trace->begin(), trace->end(), fg);
// find recursive // find recursive
@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F
} }
} else { } else {
trace->push_back(fg); trace->push_back(fg);
auto& used_fgs = manager_->func_graphs_used()[fg]; auto &used_fgs = manager_->func_graphs_used()[fg];
for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
CheckRecursiveGraphs(iter->first, trace); CheckRecursiveGraphs(iter->first, trace);
} }
@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F
} }
} }
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
MS_EXCEPTION_IF_NULL(path); MS_EXCEPTION_IF_NULL(path);
if (path->contains(fg)) { if (path->contains(fg)) {
MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked";
return false; return false;
} }
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
auto& func_graph_counter_map = manager_->func_graph_j_direct(); auto &func_graph_counter_map = manager_->func_graph_j_direct();
if (!func_graph_counter_map[fg].empty()) { if (!func_graph_counter_map[fg].empty()) {
// check g1->J(fg)->g2->g cycle; // check g1->J(fg)->g2->g cycle;
auto contains_j = auto contains_j =
@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt
path->add(fg); path->add(fg);
// check if func graphs used contains J(func_graph); // check if func graphs used contains J(func_graph);
auto& used = this->manager_->func_graphs_used(); auto &used = this->manager_->func_graphs_used();
for (auto& item : used[fg]) { for (auto &item : used[fg]) {
auto used_g = item.first; auto used_g = item.first;
if (SeekJ(used_g, path)) { if (SeekJ(used_g, path)) {
MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString()

View File

@ -46,13 +46,13 @@ class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
struct AnfNodeIndexPairHasher { struct AnfNodeIndexPairHasher {
std::size_t operator()(const std::pair<AnfNodePtr, int>& p1) const { std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
return std::hash<const AnfNode*>{}(p1.first.get()); return std::hash<const AnfNode *>{}(p1.first.get());
} }
}; };
struct AnfNodeIndexPairEqual { struct AnfNodeIndexPairEqual {
bool operator()(const std::pair<AnfNodePtr, int>& lhs, const std::pair<AnfNodePtr, int>& rhs) const { bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
return lhs == rhs; return lhs == rhs;
} }
}; };
@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>; using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>; using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>;
struct EdgeTupleHasher { struct EdgeTupleHasher {
std::size_t operator()(const EdgeTuple& p1) const { std::size_t operator()(const EdgeTuple &p1) const {
return hash_combine({std::hash<AnfNode*>{}(p1.first.get()), std::hash<int>{}(p1.second.first), return hash_combine({std::hash<AnfNode *>{}(p1.first.get()), std::hash<int>{}(p1.second.first),
std::hash<AnfNode*>{}(p1.second.second.get())}); std::hash<AnfNode *>{}(p1.second.second.get())});
} }
}; };
struct EdgeTupleEqual { struct EdgeTupleEqual {
bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const { bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const {
return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second;
} }
}; };
@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter<EdgeTuple, EdgeTupleHasher, EdgeTupleEqual>;
// FuncGraphManagerPtr: return created manager // FuncGraphManagerPtr: return created manager
FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true);
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage = true); FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true);
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs = {}, bool manage = true); FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
struct Signals { struct Signals {
Signal<void(FuncGraphPtr)> AddFuncGraph; Signal<void(FuncGraphPtr)> AddFuncGraph;
@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNode
// analysis base class // analysis base class
class FuncGraphAnalysis { class FuncGraphAnalysis {
public: public:
explicit FuncGraphAnalysis(const FuncGraphManager* const manager); explicit FuncGraphAnalysis(const FuncGraphManager *const manager);
virtual ~FuncGraphAnalysis() { manager_ = nullptr; } virtual ~FuncGraphAnalysis() { manager_ = nullptr; }
@ -130,7 +130,7 @@ class FuncGraphAnalysis {
virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {} virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}
const FuncGraphManager* manager_; const FuncGraphManager *manager_;
bool include_func_graph_none_; bool include_func_graph_none_;
}; };
@ -139,7 +139,7 @@ using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
// graphs analysis which compute in write, read needn't recompute // graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis { class DepCollector : public FuncGraphAnalysis {
public: public:
explicit DepCollector(const FuncGraphManager* manager); explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default; ~DepCollector() override = default;
void Reset() { ExtraReset(); } void Reset() { ExtraReset(); }
@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis {
class NodesCollector final : public DepCollector { class NodesCollector final : public DepCollector {
public: public:
explicit NodesCollector(const FuncGraphManager* m); explicit NodesCollector(const FuncGraphManager *m);
~NodesCollector() override = default; ~NodesCollector() override = default;
const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; } const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
size_t size() const override { return nodes_analysis_.size(); } size_t size() const override { return nodes_analysis_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector {
class CounterFuncGraphCollector : public DepCollector { class CounterFuncGraphCollector : public DepCollector {
public: public:
explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {} explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default; ~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; } FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis // inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); } size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_; FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector {
class CounterAnfNodeCollector : public DepCollector { class CounterAnfNodeCollector : public DepCollector {
public: public:
explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {} explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default; ~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; } FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); } size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
FuncGraphToAnfNodeCounterMap count_nodes_map_; FuncGraphToAnfNodeCounterMap count_nodes_map_;
@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector {
class ValueNodesCollector final : public CounterAnfNodeCollector { class ValueNodesCollector final : public CounterAnfNodeCollector {
public: public:
explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default; ~ValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector {
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
public: public:
explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphValueNodesCollector() override = default; ~FuncGraphValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
class FVDirectCollector final : public CounterAnfNodeCollector { class FVDirectCollector final : public CounterAnfNodeCollector {
public: public:
explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default; ~FVDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector {
class FuncGraphChildDirect final : public CounterFuncGraphCollector { class FuncGraphChildDirect final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphChildDirect() override = default; ~FuncGraphChildDirect() override = default;
@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector {
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
~FuncGraphParentsDirectCollector() override = default; ~FuncGraphParentsDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
// graph's all used graphs: key is g, value is g used graph // graph's all used graphs: key is g, value is g used graph
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphsUsedCollector() override = default; ~FuncGraphsUsedCollector() override = default;
@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
// graph's all user graphs: key is g, value is graphs who used g // graph's all user graphs: key is g, value is graphs who used g
class FuncGraphUsersCollector final : public CounterFuncGraphCollector { class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUsersCollector() override = default; ~FuncGraphUsersCollector() override = default;
@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
// graph's all user cnodes: key is g, value is cnodes who used g // graph's all user cnodes: key is g, value is cnodes who used g
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
public: public:
explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUserNodesCollector() override = default; ~FuncGraphUserNodesCollector() override = default;
@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
~FuncGraphJDirectCollector() override = default; ~FuncGraphJDirectCollector() override = default;
@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
// graphs analysis which need dynamic compute by DepCollector in each read // graphs analysis which need dynamic compute by DepCollector in each read
class DepComputer : public FuncGraphAnalysis { class DepComputer : public FuncGraphAnalysis {
public: public:
explicit DepComputer(const FuncGraphManager* manager); explicit DepComputer(const FuncGraphManager *manager);
~DepComputer() override = default; ~DepComputer() override = default;
void Reset() { void Reset() {
@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis {
void Recompute(); void Recompute();
void Recompute(const FuncGraphPtr& fg); void Recompute(const FuncGraphPtr &fg);
bool IsValidate() const { return validate_; } bool IsValidate() const { return validate_; }
bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; } bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }
@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis {
// graph g's all direct or proxy parents // graph g's all direct or proxy parents
class FuncGraphParentsTotalComputer final : public DepComputer { class FuncGraphParentsTotalComputer final : public DepComputer {
public: public:
explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {} explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {}
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; }
FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
size_t size() const override { return func_graph_parents_total_analysis_.size(); } size_t size() const override { return func_graph_parents_total_analysis_.size(); }
@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
private: private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>()); FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
// when SeekParents calls itself recursively, it can access these variables by class member // when SeekParents calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 1 parameter for SeekParents(). // other than pass by formal parameters, it can save 1 parameter for SeekParents().
FuncGraphToFuncGraphCounterMap* all_parents_direct_; FuncGraphToFuncGraphCounterMap *all_parents_direct_;
}; };
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
// graph's nearest parent in parents total // graph's nearest parent in parents total
class ParentComputer final : public DepComputer { class ParentComputer final : public DepComputer {
public: public:
explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {} explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
~ParentComputer() override = default; ~ParentComputer() override = default;
FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; } FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
size_t size() const override { return parent_analysis_.size(); } size_t size() const override { return parent_analysis_.size(); }
@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer {
// graph's children graph except self // graph's children graph except self
class ChildrenComputer final : public DepComputer { class ChildrenComputer final : public DepComputer {
public: public:
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {} explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
~ChildrenComputer() override = default; ~ChildrenComputer() override = default;
FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
size_t size() const override { return children_analysis_.size(); } size_t size() const override { return children_analysis_.size(); }
@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer {
// graph's children graph include self // graph's children graph include self
class ScopeComputer final : public DepComputer { class ScopeComputer final : public DepComputer {
public: public:
explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {} explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
~ScopeComputer() override = default; ~ScopeComputer() override = default;
FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; } FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
size_t size() const override { return scope_analysis_.size(); } size_t size() const override { return scope_analysis_.size(); }
@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector {
public: public:
explicit FVTotalComputer(const FuncGraphManager* m) explicit FVTotalComputer(const FuncGraphManager *m)
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
~FVTotalComputer() override = default; ~FVTotalComputer() override = default;
FVTotalMap& fv_total_analysis() { return fv_total_analysis_; } FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
size_t size() const override { return fv_total_analysis_.size(); } size_t size() const override { return fv_total_analysis_.size(); }
@ -453,10 +453,10 @@ class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector
class FuncGraphsUsedTotalComputer final : public DepComputer { class FuncGraphsUsedTotalComputer final : public DepComputer {
public: public:
explicit FuncGraphsUsedTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphsUsedTotalComputer() override = default; ~FuncGraphsUsedTotalComputer() override = default;
FuncGraphToFuncGraphSetMap& func_graph_used_total_analysis() { return func_graph_used_total_analysis_; } FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
size_t size() const override { return func_graph_used_total_analysis_.size(); } size_t size() const override { return func_graph_used_total_analysis_.size(); }
@ -473,13 +473,13 @@ using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGrap
class RecursiveComputer final : public DepComputer { class RecursiveComputer final : public DepComputer {
public: public:
explicit RecursiveComputer(const FuncGraphManager* m) : DepComputer(m) {} explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
~RecursiveComputer() override = default; ~RecursiveComputer() override = default;
RecursiveMap& recursive_map() { return recursive_map_; } RecursiveMap &recursive_map() { return recursive_map_; }
FuncGraphToBoolMap& recursive_analysis() { return recursive_analysis_; } FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
void CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace); void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
size_t size() const override { return recursive_analysis_.size(); } size_t size() const override { return recursive_analysis_.size(); }
@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer {
class FuncGraphJTotalComputer final : public DepComputer { class FuncGraphJTotalComputer final : public DepComputer {
public: public:
explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphJTotalComputer() override = default; ~FuncGraphJTotalComputer() override = default;
FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; } FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; }
size_t size() const override { return j_total_analysis_.size(); } size_t size() const override { return j_total_analysis_.size(); }
@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer {
void ExtraReset() override { j_total_analysis_.clear(); } void ExtraReset() override { j_total_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path); bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path);
}; };
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
public: public:
explicit FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage = true); explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager() { ~FuncGraphManager() {
if (is_manage_) { if (is_manage_) {
RemoveRoots(); RemoveRoots();
@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void Init(); void Init();
void Clear(); void Clear();
void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false); void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false);
void KeepRoots(const std::vector<FuncGraphPtr>& roots = {}); void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
void RemoveRoots(); void RemoveRoots();
void SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters); void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false); void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value); void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope); void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope);
FuncGraphTransaction Transact(); FuncGraphTransaction Transact();
void CommitChanges(const std::vector<Change>& changes); void CommitChanges(const std::vector<Change> &changes);
bool IsManaged() const { return is_manage_; } bool IsManaged() const { return is_manage_; }
const FuncGraphSet& roots() const { return roots_; } const FuncGraphSet &roots() const { return roots_; }
const FuncGraphSet& func_graphs() const { return func_graphs_; } const FuncGraphSet &func_graphs() const { return func_graphs_; }
AnfNodeSet& all_nodes() { return all_nodes_; } AnfNodeSet &all_nodes() { return all_nodes_; }
NodeUsersMap& node_users() { return node_users_; } NodeUsersMap &node_users() { return node_users_; }
FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; } FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const { FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_; return func_graph_child_direct_->count_func_graphs_map_;
} }
FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const { FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const {
return func_graph_parents_direct_->count_func_graphs_map_; return func_graph_parents_direct_->count_func_graphs_map_;
} }
FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }
FVTotalMap& free_variables_total() const; FVTotalMap &free_variables_total() const;
FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const; FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
FuncGraphSet& scopes(const FuncGraphPtr& fg) const; FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
FuncGraphPtr parent(const FuncGraphPtr& fg) const; FuncGraphPtr parent(const FuncGraphPtr &fg) const;
FuncGraphSet& children(const FuncGraphPtr& fg) const; FuncGraphSet &children(const FuncGraphPtr &fg) const;
FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const; FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
bool recursive(const FuncGraphPtr& fg) const; bool recursive(const FuncGraphPtr &fg) const;
std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr& fg) const; std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
bool func_graph_j_total(const FuncGraphPtr& fg) const; bool func_graph_j_total(const FuncGraphPtr &fg) const;
std::shared_ptr<Signals> signals() const { return signals_; } std::shared_ptr<Signals> signals() const { return signals_; }
IncludeType Limit(const AnfNodePtr& node); IncludeType Limit(const AnfNodePtr &node);
// Static Analysis // Static Analysis
NodeUsersMap node_users_; NodeUsersMap node_users_;
@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
std::shared_ptr<ParentComputer> func_graph_parent_; std::shared_ptr<ParentComputer> func_graph_parent_;
private: private:
void AddIntoManaged(const FuncGraphPtr& fg); void AddIntoManaged(const FuncGraphPtr &fg);
void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction);
void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction); void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction);
void AcquireNodes(const std::vector<AnfNodePtr>& nodes); void AcquireNodes(const std::vector<AnfNodePtr> &nodes);
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr>& nodes); FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes);
void ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges, void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges,
Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms); Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms);
FuncGraphSet roots_; // managed roots FuncGraphSet roots_; // managed roots
FuncGraphSet func_graphs_; // managed func graphs FuncGraphSet func_graphs_; // managed func graphs
@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
class FuncGraphTransaction { class FuncGraphTransaction {
public: public:
explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() { explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
if (!manager_->IsManaged()) { if (!manager_->IsManaged()) {
MS_LOG(DEBUG) << "The manager is not managed yet"; MS_LOG(DEBUG) << "The manager is not managed yet";
@ -648,19 +648,19 @@ class FuncGraphTransaction {
~FuncGraphTransaction() { manager_ = nullptr; } ~FuncGraphTransaction() { manager_ = nullptr; }
// set parameters of a func graph // set parameters of a func graph
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params); void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
// replace old_node with new_node // replace old_node with new_node
bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
// set esge, i.e., declare setting node.inputs[key] to value. // set esge, i.e., declare setting node.inputs[key] to value.
void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v); void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
// commit all changes // commit all changes
void Commit(); void Commit();
private: private:
FuncGraphManager* manager_; FuncGraphManager *manager_;
std::vector<Change> changes_; std::vector<Change> changes_;
}; };
@ -668,9 +668,9 @@ class FuncGraphTransaction {
struct ArgsOfSetParams { struct ArgsOfSetParams {
FuncGraphPtr func_graph; FuncGraphPtr func_graph;
std::vector<AnfNodePtr> params; std::vector<AnfNodePtr> params;
bool operator==(const ArgsOfSetParams& other) const { return &other == this; } bool operator==(const ArgsOfSetParams &other) const { return &other == this; }
friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) { friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) {
os << "[ArgsOfSetParams]"; os << "[ArgsOfSetParams]";
return os; return os;
} }
@ -681,9 +681,9 @@ struct ArgsOfSetEdge {
CNodePtr root_node; CNodePtr root_node;
AnfNodePtr new_node; AnfNodePtr new_node;
size_t index; size_t index;
bool operator==(const ArgsOfSetEdge& other) const { return &other == this; } bool operator==(const ArgsOfSetEdge &other) const { return &other == this; }
friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) { friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) {
os << "[ArgsOfSetEdge]"; os << "[ArgsOfSetEdge]";
return os; return os;
} }
@ -693,7 +693,7 @@ struct Change {
enum OpName { kTxSetParams, kTxSetEdge }; enum OpName { kTxSetParams, kTxSetEdge };
OpName op; OpName op;
Any args; Any args;
Change(OpName name, const Any& para) : op(name), args(para) {} Change(OpName name, const Any &para) : op(name), args(para) {}
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -42,25 +42,25 @@ namespace mindspore {
// generate a graph corresponding to these types. // generate a graph corresponding to these types.
class MetaFuncGraph : public FuncGraphBase { class MetaFuncGraph : public FuncGraphBase {
public: public:
explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); }
~MetaFuncGraph() override = default; ~MetaFuncGraph() override = default;
MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node);
// Return normalized versions of the arguments. // Return normalized versions of the arguments.
// By default, this returns args unchanged. // By default, this returns args unchanged.
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const { virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
return args_spec_list; return args_spec_list;
} }
const std::vector<Signature>& signatures() const { return signatures_; } const std::vector<Signature> &signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; } void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
// Generate a Graph for the given abstract arguments. // Generate a Graph for the given abstract arguments.
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) { virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
TypePtrList types; TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr& arg) -> TypePtr { [](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType(); return arg->BuildType();
}); });
@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase {
} }
// Generate a Graph for this type signature. // Generate a Graph for this type signature.
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) { virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
} }
@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase {
std::string ToString() const override { return name_; } std::string ToString() const override { return name_; }
std::size_t hash() const override { return tid(); } std::size_t hash() const override { return tid(); }
virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; } virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
bool operator==(const Value& other) const override { bool operator==(const Value &other) const override {
if (other.isa<MetaFuncGraph>()) { if (other.isa<MetaFuncGraph>()) {
return &other == this; return &other == this;
} else { } else {

View File

@ -31,7 +31,7 @@ namespace mindspore {
namespace tensor { namespace tensor {
void DataBuf2Contiguous(const py::array& src, py::array* const dest) { void DataBuf2Contiguous(const py::array &src, py::array *const dest) {
if (dest == nullptr) { if (dest == nullptr) {
MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
} }
@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
// MetaTensor has default type_id_ which is TypeId::kTypeUnknown. // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {} MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) { if (type_ptr != nullptr) {
data_type = type_ptr->type_id(); data_type = type_ptr->type_id();
@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
} }
} }
MetaTensor::MetaTensor(const MetaTensor& meta_tensor) MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
: Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
if (&meta_tensor == this) { if (&meta_tensor == this) {
return *this; return *this;
} }
@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
return *this; return *this;
} }
bool MetaTensor::operator==(const MetaTensor& meta_tensor) const { bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
} }
@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
return type_ptr; return type_ptr;
} }
void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) { void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
DeviceInfo info(format, data_type); DeviceInfo info(format, data_type);
set_device_info(info); set_device_info(info);
} }
@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const {
return oss.str(); return oss.str();
} }
Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) { if (type_ptr != nullptr) {
data_type = type_ptr->type_id(); data_type = type_ptr->type_id();
@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
init(data_type_, shape_, &data_); init(data_type_, shape_, &data_);
} }
Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); } Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); }
Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); } Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); }
Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type) Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), device_address_(tensor.device_address()) { : MetaTensor(tensor), device_address_(tensor.device_address()) {
init(tensor.data_, data_type); init(tensor.data_, data_type);
} }
Tensor& Tensor::operator=(const Tensor& tensor) { Tensor &Tensor::operator=(const Tensor &tensor) {
if (this != &tensor) { if (this != &tensor) {
MetaTensor::operator=(tensor); MetaTensor::operator=(tensor);
dirty_ = tensor.is_dirty(); dirty_ = tensor.is_dirty();
@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) {
return *this; return *this;
} }
bool Tensor::operator==(const Tensor& tensor) const { bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_); return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
} }
bool Tensor::ValueEqualPy(const py::object& other) const { bool Tensor::ValueEqualPy(const py::object &other) const {
if (!py::isinstance<Tensor>(other)) { if (!py::isinstance<Tensor>(other)) {
MS_LOG(WARNING) << "compare other not a tensor"; MS_LOG(WARNING) << "compare other not a tensor";
return false; return false;
@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const {
return ValueEqual(py::cast<Tensor>(other)); return ValueEqual(py::cast<Tensor>(other));
} }
bool Tensor::ValueEqual(const Tensor& other) const { bool Tensor::ValueEqual(const Tensor &other) const {
auto equal = [&other, this]() -> bool { auto equal = [&other, this]() -> bool {
auto np = py::module::import("numpy"); auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_); auto equal = np.attr("equal")(data_, other.data_);
@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
std::vector<int> Tensor::shape_c(void) const { return shape(); } std::vector<int> Tensor::shape_c(void) const { return shape(); }
void* Tensor::data_c(bool writable) { void *Tensor::data_c(bool writable) {
// operand of bit operation should be unsigned int. // operand of bit operation should be unsigned int.
unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
bool is_c_contiguous = (flags != 0) ? true : false; bool is_c_contiguous = (flags != 0) ? true : false;
@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) {
return data_.request(writable).ptr; return data_.request(writable).ptr;
} }
TypeId Tensor::GetDataType(const py::buffer_info& buf) const { TypeId Tensor::GetDataType(const py::buffer_info &buf) const {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (buf.format.compare("e") == 0) { if (buf.format.compare("e") == 0) {
data_type = TypeId::kNumberTypeFloat16; data_type = TypeId::kNumberTypeFloat16;
@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
return data_type; return data_type;
} }
void Tensor::init(const py::array& input, const TypePtr& type_ptr) { void Tensor::init(const py::array &input, const TypePtr &type_ptr) {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) { if (type_ptr != nullptr) {
data_type = type_ptr->type_id(); data_type = type_ptr->type_id();
@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
init(input, data_type); init(input, data_type);
} }
void Tensor::init(const py::array& input, const TypeId& data_type) { void Tensor::init(const py::array &input, const TypeId &data_type) {
py::buffer_info buf = input.request(); py::buffer_info buf = input.request();
data_type_ = GetDataType(buf); data_type_ = GetDataType(buf);
@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) {
} }
} }
void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) { void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
data_type_ = data_type; data_type_ = data_type;
shape_ = shape; shape_ = shape;
switch (data_type) { switch (data_type) {
@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
return data_type_; return data_type_;
} }
bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
const TypeId out_data_type) { const TypeId out_data_type) {
if (out == nullptr) { if (out == nullptr) {
return false; return false;
@ -458,7 +458,7 @@ py::array Tensor::data_sync() {
return data_; return data_;
} }
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// dtype should define before Tensor, because Tensor init depend dtype // dtype should define before Tensor, because Tensor init depend dtype
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape")) .def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
.def("__repr__", &Tensor::ToStringRepr) .def("__repr__", &Tensor::ToStringRepr)
.def("__eq__", &Tensor::ValueEqualPy) .def("__eq__", &Tensor::ValueEqualPy)
.def(py::pickle( .def(py::pickle(
[](const Tensor& t) { // __getstate__ [](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(t.data()); return py::make_tuple(t.data());
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }

View File

@ -131,16 +131,16 @@ class MetaTensor : public Value {
// information of a Tensor. The following codes will create a 2x3 float // information of a Tensor. The following codes will create a 2x3 float
// param data_type The data type of the tensor. // param data_type The data type of the tensor.
// param shape The shape of the tensor. // param shape The shape of the tensor.
MetaTensor(const TypeId data_type, const std::vector<int>& shape); MetaTensor(const TypeId data_type, const std::vector<int> &shape);
MetaTensor(const TypePtr& type_ptr, const py::tuple& shape); MetaTensor(const TypePtr &type_ptr, const py::tuple &shape);
// brief Constructs a MetaTensor object from an existing MetaTensor instance. // brief Constructs a MetaTensor object from an existing MetaTensor instance.
// //
// The constructed MetaTensor object will have the same data type and shape as the // The constructed MetaTensor object will have the same data type and shape as the
// meta_tensor. // meta_tensor.
// //
// param meta_tensor An existing MetaTensor object. // param meta_tensor An existing MetaTensor object.
MetaTensor(const MetaTensor& meta_tensor); MetaTensor(const MetaTensor &meta_tensor);
~MetaTensor() override = default; ~MetaTensor() override = default;
MS_DECLARE_PARENT(MetaTensor, Value) MS_DECLARE_PARENT(MetaTensor, Value)
@ -149,7 +149,7 @@ class MetaTensor : public Value {
// The constructed MetaTensor object has the same type and shape with meta_tensor. // The constructed MetaTensor object has the same type and shape with meta_tensor.
// //
// param meta_tensor An existing MetaTensor object. // param meta_tensor An existing MetaTensor object.
virtual MetaTensor& operator=(const MetaTensor& meta_tensor); virtual MetaTensor &operator=(const MetaTensor &meta_tensor);
// brief Compares two MetaTensor objects. // brief Compares two MetaTensor objects.
// //
@ -157,7 +157,7 @@ class MetaTensor : public Value {
// //
// param meta_tensor The MetaTensor object to be compared. // param meta_tensor The MetaTensor object to be compared.
// return true: If having same type and shape, return true, or return false. // return true: If having same type and shape, return true, or return false.
virtual bool operator==(const MetaTensor& meta_tensor) const; virtual bool operator==(const MetaTensor &meta_tensor) const;
// brief Returns the data type of the tensor in its MetaTensor. // brief Returns the data type of the tensor in its MetaTensor.
// //
@ -193,7 +193,7 @@ class MetaTensor : public Value {
// //
// param shape The shape of the tensor. // param shape The shape of the tensor.
// return The shape's size. // return The shape's size.
size_t set_shape(const std::vector<int>& shape) { size_t set_shape(const std::vector<int> &shape) {
this->shape_ = shape; this->shape_ = shape;
return shape_.size(); return shape_.size();
} }
@ -202,9 +202,9 @@ class MetaTensor : public Value {
DeviceInfo device_info() const { return device_info_; } DeviceInfo device_info() const { return device_info_; }
// Set tensor's device info. // Set tensor's device info.
void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; } void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
void SetDeviceInfo(const std::string& format, const TypePtr& data_type); void SetDeviceInfo(const std::string &format, const TypePtr &data_type);
// Get the size of a given dimension by its index number. // Get the size of a given dimension by its index number.
int DimensionSize(size_t index) const; int DimensionSize(size_t index) const;
@ -222,9 +222,9 @@ class MetaTensor : public Value {
} }
return hash_value; return hash_value;
} }
bool operator==(const Value& other) const override { bool operator==(const Value &other) const override {
if (other.isa<MetaTensor>()) { if (other.isa<MetaTensor>()) {
auto other_ = static_cast<const MetaTensor&>(other); auto other_ = static_cast<const MetaTensor &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
@ -262,49 +262,49 @@ class Tensor : public MetaTensor {
// //
// param type_ptr [TypePty] Data type of the tensor. // param type_ptr [TypePty] Data type of the tensor.
// param py_shape [py::tuple] The shape represented by py::tuple of the tensor. // param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
Tensor(const TypePtr& type_ptr, const py::tuple& shape); Tensor(const TypePtr &type_ptr, const py::tuple &shape);
// brief Constructor for C++. // brief Constructor for C++.
// //
// param data_type [TypeId] Data type of the tensor. // param data_type [TypeId] Data type of the tensor.
// param shape The shape represented by std::vector<int> of the tensor. // param shape The shape represented by std::vector<int> of the tensor.
Tensor(TypeId data_type, const std::vector<int>& shape); Tensor(TypeId data_type, const std::vector<int> &shape);
// brief Constructor for Python. // brief Constructor for Python.
// //
// param input [py::array] Data value of the tensor. // param input [py::array] Data value of the tensor.
// param data_type [TypeId] Data type of the tensor. // param data_type [TypeId] Data type of the tensor.
explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr); explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr);
// brief Constructor // brief Constructor
// //
// param input [py::list] the data for tensor // param input [py::list] the data for tensor
// param data_type [TypeId] data type // param data_type [TypeId] data type
explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr); explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr);
// brief Constructor // brief Constructor
// //
// param input [py::tuple] the data for tensor // param input [py::tuple] the data for tensor
// param data_type [TypeId] data type // param data_type [TypeId] data type
explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr); explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr);
// brief Constructor // brief Constructor
// //
// param input [py::float_] the data for tensor // param input [py::float_] the data for tensor
// param data_type [TypeId] data type // param data_type [TypeId] data type
explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr); explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr);
// brief Constructor // brief Constructor
// //
// param input [py::int_] the data for tensor // param input [py::int_] the data for tensor
// param data_type [TypeId] data type // param data_type [TypeId] data type
explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr); explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr);
// brief Constructor // brief Constructor
// //
// param input [Tensor] the data for tensor // param input [Tensor] the data for tensor
// param data_type [TypeId] data type // param data_type [TypeId] data type
Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr); Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr);
~Tensor() override = default; ~Tensor() override = default;
@ -315,7 +315,7 @@ class Tensor : public MetaTensor {
// The constructed Tensor object has the same type and shape with tensor. // The constructed Tensor object has the same type and shape with tensor.
// //
// param tensor An existing Tensor object. // param tensor An existing Tensor object.
Tensor& operator=(const Tensor& tensor); Tensor &operator=(const Tensor &tensor);
// brief Compares two Tensor objects. // brief Compares two Tensor objects.
// //
@ -324,17 +324,17 @@ class Tensor : public MetaTensor {
// //
// param tensor The Tensor object to be compared. // param tensor The Tensor object to be compared.
// return true: If having same type, shape and data, return true, or return false. // return true: If having same type, shape and data, return true, or return false.
bool operator==(const Tensor& tensor) const; bool operator==(const Tensor &tensor) const;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison. // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor& other) const; bool ValueEqual(const Tensor &other) const;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison. // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqualPy(const py::object& other) const; bool ValueEqualPy(const py::object &other) const;
bool operator==(const Value& other) const override { bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) { if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor&>(other); auto other_ = static_cast<const Tensor &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
@ -375,13 +375,13 @@ class Tensor : public MetaTensor {
// //
// param writable true if writable, false if read only // param writable true if writable, false if read only
// return The pointer to the object // return The pointer to the object
void* data_c(bool writable = false); void *data_c(bool writable = false);
// brief Get data type from tensor data. // brief Get data type from tensor data.
// //
// param buf The buffer info of the py::array data. // param buf The buffer info of the py::array data.
// return The [TypeId] of the tensor data. // return The [TypeId] of the tensor data.
TypeId GetDataType(const py::buffer_info& buf) const; TypeId GetDataType(const py::buffer_info &buf) const;
// brief Sets the data type of a tensor. // brief Sets the data type of a tensor.
// //
@ -401,23 +401,23 @@ class Tensor : public MetaTensor {
// param input [py::array] the data for tensor // param input [py::array] the data for tensor
// param data_type [TypeId] data type // param data_type [TypeId] data type
// return true if succeed, false if failed. // return true if succeed, false if failed.
void init(const py::array& input, const TypeId& data_type); void init(const py::array &input, const TypeId &data_type);
void init(const py::array& input, const TypePtr& type_ptr); void init(const py::array &input, const TypePtr &type_ptr);
// brief init tensor attribute // brief init tensor attribute
// //
// param data_type [TypeId] Data type of the tensor. // param data_type [TypeId] Data type of the tensor.
// param shape [py::array] The shape of the tensor. // param shape [py::array] The shape of the tensor.
// return true if succeed, false if failed. // return true if succeed, false if failed.
void init(TypeId data_type, const std::vector<int>& shape, py::array* data); void init(TypeId data_type, const std::vector<int> &shape, py::array *data);
bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type); bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type);
public: public:
bool is_dirty() const { return dirty_; } bool is_dirty() const { return dirty_; }
void set_dirty(const bool dirty) { dirty_ = dirty; } void set_dirty(const bool dirty) { dirty_ = dirty; }
DeviceAddressPtr device_address() const { return device_address_; } DeviceAddressPtr device_address() const { return device_address_; }
void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; } void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
py::array data_sync(); py::array data_sync();
private: private:

View File

@ -18,9 +18,9 @@
#include "pipeline/static_analysis/abstract_value.h" #include "pipeline/static_analysis/abstract_value.h"
namespace mindspore { namespace mindspore {
bool Named::operator==(const Value& other) const { bool Named::operator==(const Value &other) const {
if (other.isa<Named>()) { if (other.isa<Named>()) {
auto other_named = static_cast<const Named&>(other); auto other_named = static_cast<const Named &>(other);
return *this == other_named; return *this == other_named;
} else { } else {
return false; return false;

View File

@ -27,18 +27,18 @@
namespace mindspore { namespace mindspore {
class Named : public Value { class Named : public Value {
public: public:
explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); } explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
Named(const Named& other) : Value(other) { Named(const Named &other) : Value(other) {
this->name_ = other.name_; this->name_ = other.name_;
hash_id_ = std::hash<std::string>{}(other.name_); hash_id_ = std::hash<std::string>{}(other.name_);
} }
~Named() override = default; ~Named() override = default;
MS_DECLARE_PARENT(Named, Value); MS_DECLARE_PARENT(Named, Value);
const std::string& name() const { return name_; } const std::string &name() const { return name_; }
virtual bool operator==(const Named& other) const { return name_ == other.name(); } virtual bool operator==(const Named &other) const { return name_ == other.name(); }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
Named& operator=(const Named& other) { Named &operator=(const Named &other) {
if (&other != this) { if (&other != this) {
this->type_ = other.type_; this->type_ = other.type_;
this->name_ = other.name_; this->name_ = other.name_;
@ -50,7 +50,7 @@ class Named : public Value {
std::size_t Hash() const { return hash_id_; } std::size_t Hash() const { return hash_id_; }
std::size_t hash() const override { return hash_id_; } std::size_t hash() const override { return hash_id_; }
friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { friend std::ostream &operator<<(std::ostream &os, const Named &nmd) {
os << nmd.name(); os << nmd.name();
return os; return os;
} }

View File

@ -31,7 +31,7 @@
namespace mindspore { namespace mindspore {
using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunction;
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) { abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node); auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
return prim_func; return prim_func;
} }
@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() {
return fn; return fn;
} }
bool Primitive::operator==(const Value& other) const { bool Primitive::operator==(const Value &other) const {
if (other.isa<Primitive>()) { if (other.isa<Primitive>()) {
auto other_prim = static_cast<const Primitive&>(other); auto other_prim = static_cast<const Primitive &>(other);
return *this == other_prim; return *this == other_prim;
} else { } else {
return false; return false;
} }
} }
bool Primitive::operator==(const Primitive& other) const { bool Primitive::operator==(const Primitive &other) const {
if (name() != other.name()) { if (name() != other.name()) {
return false; return false;
} }
if (attrs_.size() != other.attrs_.size()) { if (attrs_.size() != other.attrs_.size()) {
return false; return false;
} }
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr>& item) -> bool { auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
if (item.second == nullptr) { if (item.second == nullptr) {
return false; return false;
} }
@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const {
void Primitive::set_signatures( void Primitive::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear(); signatures_.clear();
for (auto& signature : signatures) { for (auto &signature : signatures) {
std::string name; std::string name;
SignatureEnumRW rw; SignatureEnumRW rw;
SignatureEnumKind kind; SignatureEnumKind kind;
@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const {
std::ostringstream oss; std::ostringstream oss;
oss << "["; oss << "[";
bool is_first = true; bool is_first = true;
for (auto& attr : attrs_) { for (auto &attr : attrs_) {
if (is_first) { if (is_first) {
is_first = false; is_first = false;
} else { } else {
@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const {
} }
py::function PrimitivePy::GetBpropFunction() { py::function PrimitivePy::GetBpropFunction() {
static const char* const get_bprop_func_name = "get_bprop"; static const char *const get_bprop_func_name = "get_bprop";
if (py::hasattr(python_obj_, get_bprop_func_name)) { if (py::hasattr(python_obj_, get_bprop_func_name)) {
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>(); py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
return fn; return fn;
@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() {
} }
py::function PrimitivePy::GetComputeFunction() { py::function PrimitivePy::GetComputeFunction() {
static const char* const compute_func_name = "vm_impl"; static const char *const compute_func_name = "vm_impl";
if (py::hasattr(python_obj_, compute_func_name)) { if (py::hasattr(python_obj_, compute_func_name)) {
MS_LOG(INFO) << "" << name() << " compute_func_name"; MS_LOG(INFO) << "" << name() << " compute_func_name";
@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() {
return vm_fn; return vm_fn;
} }
void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
std::string attr_name = name; std::string attr_name = name;
ValuePtr converted_ret = nullptr; ValuePtr converted_ret = nullptr;
if (py::isinstance<py::module>(obj)) { if (py::isinstance<py::module>(obj)) {
@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) {
py::dict PrimitivePy::GetAttrDict() { py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict; py::dict attr_dict;
for (auto& attr : attrs_) { for (auto &attr : attrs_) {
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
} }
return attr_dict; return attr_dict;
} }
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown) .value("unknown", PrimType::kPrimTypeUnknown)
.value("builtin", PrimType::kPrimTypeBuiltIn) .value("builtin", PrimType::kPrimTypeBuiltIn)
@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) {
.value("user_custom", PrimType::kPrimTypeUserCustom); .value("user_custom", PrimType::kPrimTypeUserCustom);
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str&, py::object>()) .def(py::init<py::str &, py::object>())
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")

View File

@ -48,25 +48,25 @@ enum PrimType {
class Primitive : public Named { class Primitive : public Named {
public: public:
explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn) explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name), signatures_(), prim_type_(prim_type) {} : Named(name), signatures_(), prim_type_(prim_type) {}
Primitive(const Primitive& prim) Primitive(const Primitive &prim)
: Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
MS_DECLARE_PARENT(Primitive, Named); MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node); abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); } std::string ToString() const override { return name(); }
virtual py::function GetBpropFunction(); virtual py::function GetBpropFunction();
virtual py::function GetComputeFunction(); virtual py::function GetComputeFunction();
Primitive& AddAttr(const std::string& name, const ValuePtr& attr) { Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr; attrs_[name] = attr;
return *this; return *this;
} }
Primitive& SetAttrs(const std::unordered_map<std::string, ValuePtr>& attrs) { Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto& attr : attrs) { for (auto &attr : attrs) {
attrs_[attr.first] = attr.second; attrs_[attr.first] = attr.second;
} }
return *this; return *this;
@ -76,21 +76,21 @@ class Primitive : public Named {
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
signatures); signatures);
const std::vector<Signature>& signatures() const { return signatures_; } const std::vector<Signature> &signatures() const { return signatures_; }
void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; } void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); } void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
ValuePtr GetAttr(const std::string& attrName) const { ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName); auto iter = attrs_.find(attrName);
return iter == attrs_.cend() ? nullptr : iter->second; return iter == attrs_.cend() ? nullptr : iter->second;
} }
const std::unordered_map<std::string, ValuePtr>& attrs() const { return attrs_; } const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); } bool HasAttr() const { return !attrs_.empty(); }
bool HasAttr(const std::string& attrName) const { bool HasAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName); auto iter = attrs_.find(attrName);
return !(iter == attrs_.cend()); return !(iter == attrs_.cend());
} }
@ -103,8 +103,8 @@ class Primitive : public Named {
PrimType prim_type() const { return prim_type_; } PrimType prim_type() const { return prim_type_; }
std::string instance_name() const { return instance_name_; } std::string instance_name() const { return instance_name_; }
std::string GetAttrsText() const; std::string GetAttrsText() const;
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const Primitive& other) const; bool operator==(const Primitive &other) const;
~Primitive() override = default; ~Primitive() override = default;
protected: protected:
@ -118,18 +118,18 @@ class Primitive : public Named {
class PrimitivePy : public Primitive { class PrimitivePy : public Primitive {
public: public:
PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {} PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {}
~PrimitivePy() override = default; ~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive); MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction() override; py::function GetBpropFunction() override;
py::function GetComputeFunction() override; py::function GetComputeFunction() override;
void AddPyAttr(const py::str& name, const py::object& obj); void AddPyAttr(const py::str &name, const py::object &obj);
py::dict GetAttrDict(); py::dict GetAttrDict();
const bool parse_info_ = true; const bool parse_info_ = true;
const py::object& GetPyObj() const { return python_obj_; } const py::object &GetPyObj() const { return python_obj_; }
bool is_tuple_input_ = false; bool is_tuple_input_ = false;
private: private:
@ -138,13 +138,13 @@ class PrimitivePy : public Primitive {
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) { inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
os << *p; os << *p;
return os; return os;
} }
struct PrimitiveEqual { struct PrimitiveEqual {
bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const { bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2); MS_EXCEPTION_IF_NULL(t2);
return t1->name() == t2->name(); return t1->name() == t2->name();
@ -152,7 +152,7 @@ struct PrimitiveEqual {
}; };
struct PrimitiveHasher { struct PrimitiveHasher {
std::size_t operator()(PrimitivePtr const& prim) const { std::size_t operator()(PrimitivePtr const &prim) const {
std::size_t hash = std::hash<std::string>()(prim->name()); std::size_t hash = std::hash<std::string>()(prim->name());
return hash; return hash;
} }

View File

@ -55,8 +55,8 @@ class BoolImm : public Scalar {
bool value() const { return v_; } bool value() const { return v_; }
bool IsZero() override { return v_ == false; } bool IsZero() override { return v_ == false; }
bool IsOne() override { return v_ == true; } bool IsOne() override { return v_ == true; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const BoolImm& other) const; bool operator==(const BoolImm &other) const;
std::string ToString() const override { std::string ToString() const override {
if (v_) { if (v_) {
return "true"; return "true";
@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool)
class IntergerImm : public Scalar { class IntergerImm : public Scalar {
public: public:
IntergerImm() = default; IntergerImm() = default;
explicit IntergerImm(const TypePtr& t) : Scalar(t) {} explicit IntergerImm(const TypePtr &t) : Scalar(t) {}
~IntergerImm() override = default; ~IntergerImm() override = default;
MS_DECLARE_PARENT(IntergerImm, Scalar) MS_DECLARE_PARENT(IntergerImm, Scalar)
}; };
@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
int8_t value() const { return v_; } int8_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const Int8Imm& other) const; bool operator==(const Int8Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
int16_t value() const { return v_; } int16_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const Int16Imm& other) const; bool operator==(const Int16Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
int32_t value() const { return v_; } int32_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const Int32Imm& other) const; bool operator==(const Int32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
int64_t value() const { return v_; } int64_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const Int64Imm& other) const; bool operator==(const Int64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
uint8_t value() const { return v_; } uint8_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const UInt8Imm& other) const; bool operator==(const UInt8Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
uint16_t value() const { return v_; } uint16_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const UInt16Imm& other) const; bool operator==(const UInt16Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
uint32_t value() const { return v_; } uint32_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const UInt32Imm& other) const; bool operator==(const UInt32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; } bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; } bool IsOne() override { return v_ == 1; }
uint64_t value() const { return v_; } uint64_t value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const UInt64Imm& other) const; bool operator==(const UInt64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t);
class FloatImm : public Scalar { class FloatImm : public Scalar {
public: public:
FloatImm() = default; FloatImm() = default;
explicit FloatImm(const TypePtr& t) : Scalar(t) {} explicit FloatImm(const TypePtr &t) : Scalar(t) {}
~FloatImm() override = default; ~FloatImm() override = default;
MS_DECLARE_PARENT(FloatImm, Scalar) MS_DECLARE_PARENT(FloatImm, Scalar)
}; };
@ -312,8 +312,8 @@ class FP32Imm : public FloatImm {
bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
float value() const { return v_; } float value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const FP32Imm& other) const; bool operator==(const FP32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {
@ -338,8 +338,8 @@ class FP64Imm : public FloatImm {
bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
double value() const { return v_; } double value() const { return v_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const FP64Imm& other) const; bool operator==(const FP64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); } std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override { std::string DumpText() const override {

View File

@ -21,8 +21,8 @@
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
namespace mindspore { namespace mindspore {
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object& arg_default, const SignatureEnumDType& arg_dtype) const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) && if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag,
} }
} }
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name), : name(arg_name),
rw(rw_tag), rw(rw_tag),
kind(arg_kind), kind(arg_kind),
default_value(nullptr), default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_READ", SignatureEnumRW::kRWRead)
.value("RW_WRITE", SignatureEnumRW::kRWWrite) .value("RW_WRITE", SignatureEnumRW::kRWWrite)

View File

@ -61,9 +61,9 @@ struct Signature {
SignatureEnumKind kind; SignatureEnumKind kind;
ValuePtr default_value; // nullptr for no default value ValuePtr default_value; // nullptr for no default value
SignatureEnumDType dtype; SignatureEnumDType dtype;
Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object& arg_default, const SignatureEnumDType& arg_dtype); const py::object &arg_default, const SignatureEnumDType &arg_dtype);
Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind); Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind);
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -24,7 +24,7 @@
#include "pipeline/static_analysis/abstract_value.h" #include "pipeline/static_analysis/abstract_value.h"
namespace mindspore { namespace mindspore {
const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const { const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
if (dim >= size()) { if (dim >= size()) {
MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "].";
} }
@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) {
} }
} }
bool BoolImm::operator==(const Value& other) const { bool BoolImm::operator==(const Value &other) const {
if (other.isa<BoolImm>()) { if (other.isa<BoolImm>()) {
auto other_ = static_cast<const BoolImm&>(other); auto other_ = static_cast<const BoolImm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; } bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; }
bool Int8Imm::operator==(const Value& other) const { bool Int8Imm::operator==(const Value &other) const {
if (other.isa<Int8Imm>()) { if (other.isa<Int8Imm>()) {
auto other_ = static_cast<const Int8Imm&>(other); auto other_ = static_cast<const Int8Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; } bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; }
bool Int16Imm::operator==(const Value& other) const { bool Int16Imm::operator==(const Value &other) const {
if (other.isa<Int16Imm>()) { if (other.isa<Int16Imm>()) {
auto other_ = static_cast<const Int16Imm&>(other); auto other_ = static_cast<const Int16Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; } bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; }
bool Int32Imm::operator==(const Value& other) const { bool Int32Imm::operator==(const Value &other) const {
if (other.isa<Int32Imm>()) { if (other.isa<Int32Imm>()) {
auto other_ = static_cast<const Int32Imm&>(other); auto other_ = static_cast<const Int32Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; } bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; }
bool Int64Imm::operator==(const Value& other) const { bool Int64Imm::operator==(const Value &other) const {
if (other.isa<Int64Imm>()) { if (other.isa<Int64Imm>()) {
auto other_ = static_cast<const Int64Imm&>(other); auto other_ = static_cast<const Int64Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; } bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; }
bool UInt8Imm::operator==(const Value& other) const { bool UInt8Imm::operator==(const Value &other) const {
if (other.isa<UInt8Imm>()) { if (other.isa<UInt8Imm>()) {
auto other_ = static_cast<const UInt8Imm&>(other); auto other_ = static_cast<const UInt8Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; } bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; }
bool UInt16Imm::operator==(const Value& other) const { bool UInt16Imm::operator==(const Value &other) const {
if (other.isa<UInt16Imm>()) { if (other.isa<UInt16Imm>()) {
auto other_ = static_cast<const UInt16Imm&>(other); auto other_ = static_cast<const UInt16Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; } bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; }
bool UInt32Imm::operator==(const Value& other) const { bool UInt32Imm::operator==(const Value &other) const {
if (other.isa<UInt32Imm>()) { if (other.isa<UInt32Imm>()) {
auto other_ = static_cast<const UInt32Imm&>(other); auto other_ = static_cast<const UInt32Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; } bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; }
bool UInt64Imm::operator==(const Value& other) const { bool UInt64Imm::operator==(const Value &other) const {
if (other.isa<UInt64Imm>()) { if (other.isa<UInt64Imm>()) {
auto other_ = static_cast<const UInt64Imm&>(other); auto other_ = static_cast<const UInt64Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; } bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; }
bool FP32Imm::operator==(const Value& other) const { bool FP32Imm::operator==(const Value &other) const {
if (other.isa<FP32Imm>()) { if (other.isa<FP32Imm>()) {
auto other_ = static_cast<const FP32Imm&>(other); auto other_ = static_cast<const FP32Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; }
bool FP64Imm::operator==(const Value& other) const { bool FP64Imm::operator==(const Value &other) const {
if (other.isa<FP64Imm>()) { if (other.isa<FP64Imm>()) {
auto other_ = static_cast<const FP64Imm&>(other); auto other_ = static_cast<const FP64Imm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool ValueSequeue::operator==(const Value& other) const { bool ValueSequeue::operator==(const Value &other) const {
if (other.isa<ValueSequeue>()) { if (other.isa<ValueSequeue>()) {
auto other_ = static_cast<const ValueSequeue&>(other); auto other_ = static_cast<const ValueSequeue &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool ValueSequeue::operator==(const ValueSequeue& other) const { bool ValueSequeue::operator==(const ValueSequeue &other) const {
if (other.elements_.size() != elements_.size()) { if (other.elements_.size() != elements_.size()) {
return false; return false;
} }
return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(),
[](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; }); [](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; });
} }
std::string ValueSequeue::ToString() const { std::string ValueSequeue::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
bool begin = true; bool begin = true;
for (auto& attr : elements_) { for (auto &attr : elements_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const {
return oss.str(); return oss.str();
} }
bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; }
bool StringImm::operator==(const Value& other) const { bool StringImm::operator==(const Value &other) const {
if (other.isa<StringImm>()) { if (other.isa<StringImm>()) {
auto other_ = static_cast<const StringImm&>(other); auto other_ = static_cast<const StringImm &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; } bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }
bool RefKey::operator==(const Value& other) const { bool RefKey::operator==(const Value &other) const {
if (other.isa<RefKey>()) { if (other.isa<RefKey>()) {
auto other_ = static_cast<const RefKey&>(other); auto other_ = static_cast<const RefKey &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; } bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
bool AnyValue::operator==(const Value& other) const { bool AnyValue::operator==(const Value &other) const {
if (other.isa<AnyValue>()) { if (other.isa<AnyValue>()) {
return true; return true;
} else { } else {
@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstr
abstract::AbstractBasePtr ValueTuple::ToAbstract() { abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtrList a_list; abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele); MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract(); return ele->ToAbstract();
}); });
@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtr ValueList::ToAbstract() { abstract::AbstractBasePtr ValueList::ToAbstract() {
abstract::AbstractBasePtrList a_list; abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele); MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract(); return ele->ToAbstract();
}); });
@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const {
return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
} }
bool ValueSlice::operator==(const Value& other) const { bool ValueSlice::operator==(const Value &other) const {
if (other.isa<ValueSlice>()) { if (other.isa<ValueSlice>()) {
auto other_ = static_cast<const ValueSlice&>(other); auto other_ = static_cast<const ValueSlice &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool ValueSlice::operator==(const ValueSlice& other) const { bool ValueSlice::operator==(const ValueSlice &other) const {
MS_EXCEPTION_IF_NULL(start_); MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_); MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_); MS_EXCEPTION_IF_NULL(step_);
@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const {
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()}); return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
} }
bool KeywordArg::operator==(const Value& other) const { bool KeywordArg::operator==(const Value &other) const {
if (other.isa<KeywordArg>()) { if (other.isa<KeywordArg>()) {
auto other_ = static_cast<const KeywordArg&>(other); auto other_ = static_cast<const KeywordArg &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); } bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); }
std::string KeywordArg::ToString() const { std::string KeywordArg::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() {
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument); return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
} }
const ValuePtr ValueDictionary::operator[](const std::string& key) const { const ValuePtr ValueDictionary::operator[](const std::string &key) const {
auto it = std::find_if(key_values_.begin(), key_values_.end(), auto it = std::find_if(key_values_.begin(), key_values_.end(),
[key](const std::pair<std::string, ValuePtr>& item) { return item.first == key; }); [key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
if (it == key_values_.end()) { if (it == key_values_.end()) {
MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; MS_LOG(EXCEPTION) << "The key " << key << " is not in the map";
} }
return it->second; return it->second;
} }
bool ValueDictionary::operator==(const Value& other) const { bool ValueDictionary::operator==(const Value &other) const {
if (other.isa<ValueDictionary>()) { if (other.isa<ValueDictionary>()) {
auto other_ = static_cast<const ValueDictionary&>(other); auto other_ = static_cast<const ValueDictionary &>(other);
return *this == other_; return *this == other_;
} else { } else {
return false; return false;
} }
} }
bool ValueDictionary::operator==(const ValueDictionary& other) const { bool ValueDictionary::operator==(const ValueDictionary &other) const {
if (key_values_.size() != other.key_values_.size()) { if (key_values_.size() != other.key_values_.size()) {
return false; return false;
} }
@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv; std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
(void)std::transform( (void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv), key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, ValuePtr>& item) { return std::make_pair(item.first, item.second->ToAbstract()); }); [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
return std::make_shared<abstract::AbstractDictionary>(kv); return std::make_shared<abstract::AbstractDictionary>(kv);
} }
REGISTER_PYBIND_DEFINE( REGISTER_PYBIND_DEFINE(
RefKey, ([](const py::module* m) { RefKey, ([](const py::module *m) {
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag")); (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
})); }));
} // namespace mindspore } // namespace mindspore

View File

@ -35,19 +35,19 @@
namespace mindspore { namespace mindspore {
class ValueSequeue : public Value { class ValueSequeue : public Value {
public: public:
explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) { explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) {
TypePtrList t_list; TypePtrList t_list;
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) { (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele); MS_EXCEPTION_IF_NULL(ele);
return ele->type(); return ele->type();
}); });
TypePtr t = std::make_shared<Tuple>(t_list); TypePtr t = std::make_shared<Tuple>(t_list);
type_ = t; type_ = t;
} }
ValueSequeue(const std::initializer_list<ValuePtr>& elements) : elements_(elements.begin(), elements.end()) { ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) {
TypePtrList t_list; TypePtrList t_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
[](const ValuePtr& ele) { return ele->type(); }); [](const ValuePtr &ele) { return ele->type(); });
TypePtr t = std::make_shared<Tuple>(t_list); TypePtr t = std::make_shared<Tuple>(t_list);
type_ = t; type_ = t;
} }
@ -56,10 +56,10 @@ class ValueSequeue : public Value {
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); } std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
std::size_t size() const { return elements_.size(); } std::size_t size() const { return elements_.size(); }
bool erase(size_t idx); bool erase(size_t idx);
const ValuePtr operator[](const std::size_t& dim) const; const ValuePtr operator[](const std::size_t &dim) const;
const ValuePtrList& value() const { return elements_; } const ValuePtrList &value() const { return elements_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const ValueSequeue& other) const; bool operator==(const ValueSequeue &other) const;
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr<ValueSequeue>;
class ValueTuple : public ValueSequeue { class ValueTuple : public ValueSequeue {
public: public:
explicit ValueTuple(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {} explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueTuple(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {} ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
~ValueTuple() override = default; ~ValueTuple() override = default;
MS_DECLARE_PARENT(ValueTuple, ValueSequeue) MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr<ValueTuple>;
class ValueList : public ValueSequeue { class ValueList : public ValueSequeue {
public: public:
explicit ValueList(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {} explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueList(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {} ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
~ValueList() override = default; ~ValueList() override = default;
MS_DECLARE_PARENT(ValueList, ValueSequeue) MS_DECLARE_PARENT(ValueList, ValueSequeue)
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
@ -94,7 +94,7 @@ class ValueList : public ValueSequeue {
}; };
using ValueListPtr = std::shared_ptr<ValueList>; using ValueListPtr = std::shared_ptr<ValueList>;
inline ValuePtr MakeValue(const std::vector<ValuePtr>& v) { return std::make_shared<ValueTuple>(v); } inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); }
inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); } inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
template <typename T> template <typename T>
@ -103,7 +103,7 @@ template <typename T, typename A>
struct is_vector<std::vector<T, A>> : public std::true_type {}; struct is_vector<std::vector<T, A>> : public std::true_type {};
template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type> template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
ValuePtr MakeValue(const T& vec) { ValuePtr MakeValue(const T &vec) {
std::vector<ValuePtr> list; std::vector<ValuePtr> list;
(void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
return std::make_shared<ValueTuple>(list); return std::make_shared<ValueTuple>(list);
@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) {
class ValueSlice : public Value { class ValueSlice : public Value {
public: public:
ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step) ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step)
: start_(start), stop_(stop), step_(step) {} : start_(start), stop_(stop), step_(step) {}
~ValueSlice() override = default; ~ValueSlice() override = default;
MS_DECLARE_PARENT(ValueSlice, Value) MS_DECLARE_PARENT(ValueSlice, Value)
std::size_t hash() const override; std::size_t hash() const override;
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const ValueSlice& other) const; bool operator==(const ValueSlice &other) const;
std::string ToString() const override; std::string ToString() const override;
@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr<ValueSlice>;
class KeywordArg : public Value { class KeywordArg : public Value {
public: public:
KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {} KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {}
~KeywordArg() override = default; ~KeywordArg() override = default;
MS_DECLARE_PARENT(KeywordArg, Value) MS_DECLARE_PARENT(KeywordArg, Value)
std::size_t hash() const override; std::size_t hash() const override;
ValuePtr get_value() const { return value_; } ValuePtr get_value() const { return value_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const KeywordArg& other) const; bool operator==(const KeywordArg &other) const;
std::string ToString() const override; std::string ToString() const override;
@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr<KeywordArg>;
class ValueDictionary : public Value { class ValueDictionary : public Value {
public: public:
explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>>& key_values) : key_values_(key_values) {} explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {}
~ValueDictionary() override = default; ~ValueDictionary() override = default;
MS_DECLARE_PARENT(ValueDictionary, Value) MS_DECLARE_PARENT(ValueDictionary, Value)
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); } std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
std::size_t size() const { return key_values_.size(); } std::size_t size() const { return key_values_.size(); }
const ValuePtr operator[](const std::string& key) const; const ValuePtr operator[](const std::string &key) const;
const std::vector<std::pair<std::string, ValuePtr>>& value() const { return key_values_; } const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const ValueDictionary& other) const; bool operator==(const ValueDictionary &other) const;
std::string ToString() const override { std::string ToString() const override {
std::ostringstream buffer; std::ostringstream buffer;
std::vector<std::string> keys; std::vector<std::string> keys;
std::vector<ValuePtr> values; std::vector<ValuePtr> values;
for (const auto& kv : key_values_) { for (const auto &kv : key_values_) {
keys.push_back(kv.first); keys.push_back(kv.first);
values.push_back(kv.second); values.push_back(kv.second);
} }
buffer << "(Dict: " buffer << "(Dict: "
<< " keys:("; << " keys:(";
for (const auto& key : keys) { for (const auto &key : keys) {
buffer << key << ", "; buffer << key << ", ";
} }
buffer << ") values:("; buffer << ") values:(";
for (const auto& value : values) { for (const auto &value : values) {
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
buffer << value->DumpText() << ", "; buffer << value->DumpText() << ", ";
} }
@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>;
class StringImm : public Value { class StringImm : public Value {
public: public:
explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {} explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
~StringImm() override = default; ~StringImm() override = default;
MS_DECLARE_PARENT(StringImm, Value) MS_DECLARE_PARENT(StringImm, Value)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
const std::string& value() const { return str_; } const std::string &value() const { return str_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const StringImm& other) const; bool operator==(const StringImm &other) const;
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override { return str_; } std::string ToString() const override { return str_; }
@ -218,18 +218,18 @@ class StringImm : public Value {
}; };
using StringImmPtr = std::shared_ptr<StringImm>; using StringImmPtr = std::shared_ptr<StringImm>;
IMM_TRAITS(StringImmPtr, std::string) IMM_TRAITS(StringImmPtr, std::string)
IMM_TRAITS(StringImmPtr, const char*) IMM_TRAITS(StringImmPtr, const char *)
class RefKey : public Value { class RefKey : public Value {
public: public:
explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
~RefKey() override = default; ~RefKey() override = default;
MS_DECLARE_PARENT(RefKey, Value) MS_DECLARE_PARENT(RefKey, Value)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
const std::string& tag() const { return tag_; } const std::string &tag() const { return tag_; }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
bool operator==(const RefKey& other) const; bool operator==(const RefKey &other) const;
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override { return "RefKey[" + tag_ + "]"; } std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
@ -251,13 +251,13 @@ class AnyValue : public Value {
~AnyValue() override = default; ~AnyValue() override = default;
MS_DECLARE_PARENT(AnyValue, Value) MS_DECLARE_PARENT(AnyValue, Value)
std::size_t hash() const override { return tid(); } std::size_t hash() const override { return tid(); }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
}; };
extern const ValuePtr kAnyValue; extern const ValuePtr kAnyValue;
template <> template <>
inline const char* GetValue(const ValuePtr& value) { inline const char *GetValue(const ValuePtr &value) {
if (value == nullptr) { if (value == nullptr) {
MS_LOG(EXCEPTION) << "Value is nullptr"; MS_LOG(EXCEPTION) << "Value is nullptr";
} }
@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) {
template <typename T, typename S = typename std::decay<T>::type, template <typename T, typename S = typename std::decay<T>::type,
typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type> typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
std::vector<U> GetValue(const ValuePtr& value) { std::vector<U> GetValue(const ValuePtr &value) {
if (value == nullptr) { if (value == nullptr) {
MS_LOG(EXCEPTION) << "Value is nullptr"; MS_LOG(EXCEPTION) << "Value is nullptr";
} }
@ -280,21 +280,21 @@ std::vector<U> GetValue(const ValuePtr& value) {
<< ">"; << ">";
} }
std::vector<U> rets; std::vector<U> rets;
const std::vector<ValuePtr>& vals = value->cast<ValueSequeuePtr>()->value(); const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value();
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
[](const ValuePtr& v) { return GetValue<U>(v); }); [](const ValuePtr &v) { return GetValue<U>(v); });
return rets; return rets;
} }
inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared<ValueNode>(t); } inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); }
template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type> template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
inline ValueNodePtr NewValueNode(const std::shared_ptr<T>& x) { inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) {
return NewValueNode(MakeValue(x)); return NewValueNode(MakeValue(x));
} }
template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type> template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
inline ValueNodePtr NewValueNode(const T& x) { inline ValueNodePtr NewValueNode(const T &x) {
return NewValueNode(MakeValue(x)); return NewValueNode(MakeValue(x));
} }
} // namespace mindspore } // namespace mindspore

View File

@ -22,15 +22,15 @@
#include "optimizer/opt.h" #include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
using VisitFuncType = std::function<void(const AnfNodePtr&)>; using VisitFuncType = std::function<void(const AnfNodePtr &)>;
class AnfVisitor { class AnfVisitor {
public: public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&); virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &);
virtual void Visit(const AnfNodePtr&); virtual void Visit(const AnfNodePtr &);
virtual void Visit(const CNodePtr&); virtual void Visit(const CNodePtr &);
virtual void Visit(const ValueNodePtr&); virtual void Visit(const ValueNodePtr &);
virtual void Visit(const ParameterPtr&); virtual void Visit(const ParameterPtr &);
VisitFuncType Match(const PrimitivePtr&, const std::vector<opt::PredicateFuncType>& = {}); VisitFuncType Match(const PrimitivePtr &, const std::vector<opt::PredicateFuncType> & = {});
virtual ~AnfVisitor() = default; virtual ~AnfVisitor() = default;
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -26,12 +26,12 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace { namespace {
void FilterInvaildKernelInfo(const CNodePtr& kernel_node, void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) { std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
[&](const std::shared_ptr<kernel::KernelBuildInfo>& kernel_build_info) { [&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
}); });
@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
} }
} }
} // namespace } // namespace
void KernelQuery(const CNodePtr& kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) { void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list); TbeMetadataInfo(kernel_node, kernel_info_list);

View File

@ -38,11 +38,11 @@ class OpAttr {
std::string value() const { return value_; } std::string value() const { return value_; }
std::string default_value() const { return default_value_; } std::string default_value() const { return default_value_; }
void set_name(const std::string& name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
void set_param_type(const std::string& param_type) { param_type_ = param_type; } void set_param_type(const std::string &param_type) { param_type_ = param_type; }
void set_type(const std::string& type) { type_ = type; } void set_type(const std::string &type) { type_ = type; }
void set_value(const std::string& value) { value_ = value; } void set_value(const std::string &value) { value_ = value; }
void set_default_value(const std::string& default_value) { default_value_ = default_value; } void set_default_value(const std::string &default_value) { default_value_ = default_value; }
private: private:
std::string name_; std::string name_;
@ -67,13 +67,13 @@ class OpIOInfo {
std::vector<std::string> formats() const { return formats_; } std::vector<std::string> formats() const { return formats_; }
void set_index(const int index) { index_ = index; } void set_index(const int index) { index_ = index; }
void set_name(const std::string& name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
void set_param_type(const std::string& param_type) { param_type_ = param_type; } void set_param_type(const std::string &param_type) { param_type_ = param_type; }
void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; }
void set_shape(const std::string& shape) { shape_ = shape; } void set_shape(const std::string &shape) { shape_ = shape; }
void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; } void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
void set_formats(const std::vector<std::string>& formats) { formats_ = formats; } void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
private: private:
int index_ = 0; int index_ = 0;
@ -104,24 +104,24 @@ class OpInfo {
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
const std::unordered_map<size_t, size_t>& ref_infos() const { return ref_infos_; } const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
void set_op_name(const std::string& op_name) { op_name_ = op_name; } void set_op_name(const std::string &op_name) { op_name_ = op_name; }
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; } void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; } void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; }
void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } void set_async_flag(const bool async_flag) { async_flag_ = async_flag; }
void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; } void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; }
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); } void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); } void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); } void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& inputs) { inputs_ptr_ = inputs; } void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; }
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& outputs) { outputs_ptr_ = outputs; } void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; }
bool is_ref() const { return !ref_infos_.empty(); } bool is_ref() const { return !ref_infos_.empty(); }
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }

View File

@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) {
return "unknow"; return "unknow";
} }
} }
bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) { bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
bool ret = false; bool ret = false;
try { try {
auto op_json = nlohmann::json::parse(json_string); auto op_json = nlohmann::json::parse(json_string);
@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
if (!ret) { if (!ret) {
MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string;
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); MS_LOG(DEBUG) << "get op_json elements failed:" << e.what();
} }
return ret; return ret;
} }
void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) { void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
op_info->set_async_flag(obj.at(kAsyncFlag)); op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName)); op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost)); op_info->set_compute_cost(obj.at(kComputeCost));
@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p
} }
} }
bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
const std::string& impl_path) { const std::string &impl_path) {
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
op_info->set_op_name(obj.at(kOpName)); op_info->set_op_name(obj.at(kOpName));
@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
DecodeTBESpecificInfo(obj, op_info); DecodeTBESpecificInfo(obj, op_info);
} }
auto attrs = obj.at(kAttr); auto attrs = obj.at(kAttr);
for (const auto& attr : attrs) { for (const auto &attr : attrs) {
if (!DecodeAttr(attr, imply_type, op_info)) { if (!DecodeAttr(attr, imply_type, op_info)) {
MS_LOG(DEBUG) << "DecodeAttr Failed"; MS_LOG(DEBUG) << "DecodeAttr Failed";
return false; return false;
@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
dtype_format = obj.at(kDtypeFormat); dtype_format = obj.at(kDtypeFormat);
} }
auto inputs = obj.at(kIputs); auto inputs = obj.at(kIputs);
for (const auto& input : inputs) { for (const auto &input : inputs) {
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed"; MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false; return false;
} }
} }
auto outputs = obj.at(kOutputs); auto outputs = obj.at(kOutputs);
for (const auto& output : outputs) { for (const auto &output : outputs) {
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed"; MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false; return false;
@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
return true; return true;
} }
bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo>& op_info) { const std::shared_ptr<OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
bool ret = true; bool ret = true;
try { try {
@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
op_attr->set_default_value(obj.at(kDefaultValue)); op_attr->set_default_value(obj.at(kDefaultValue));
} }
op_info->add_attrs_ptr(op_attr); op_info->add_attrs_ptr(op_attr);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what();
ret = false; ret = false;
} }
return ret; return ret;
} }
bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
size_t index) { size_t index) {
bool ret = true; bool ret = true;
try { try {
std::vector<std::string> dtype; std::vector<std::string> dtype;
std::vector<std::string> format; std::vector<std::string> format;
for (const auto& it : dtype_format) { for (const auto &it : dtype_format) {
dtype.emplace_back(it[index][0]); dtype.emplace_back(it[index][0]);
format.emplace_back(it[index][1]); format.emplace_back(it[index][1]);
} }
op_io->set_dtypes(dtype); op_io->set_dtypes(dtype);
op_io->set_formats(format); op_io->set_formats(format);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
ret = false; ret = false;
} }
return ret; return ret;
} }
bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) { const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
bool ret = true; bool ret = true;
try { try {
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>(); std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
} else if (io_type == kOutput) { } else if (io_type == kOutput) {
op_info->add_outputs_ptr(op_io); op_info->add_outputs_ptr(op_io);
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what();
ret = false; ret = false;
} }
return ret; return ret;
} }
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) { std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice); bool is_gpu = (context->device_target() == kGPUDevice);
@ -260,7 +260,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
<< ", current op num:" << op_info_.size(); << ", current op num:" << op_info_.size();
return nullptr; return nullptr;
} }
for (const auto& op_info : op_info_) { for (const auto &op_info : op_info_) {
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
return op_info; return op_info;
@ -271,14 +271,14 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
return nullptr; return nullptr;
} }
bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) { bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
const auto& output_infos = op_info->outputs_ptr(); const auto &output_infos = op_info->outputs_ptr();
const auto& input_infos = op_info->inputs_ptr(); const auto &input_infos = op_info->inputs_ptr();
for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
const auto& out_name = output_infos[out_index]->name(); const auto &out_name = output_infos[out_index]->name();
for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
const auto& in_name = input_infos[in_index]->name(); const auto &in_name = input_infos[in_index]->name();
if (out_name == in_name) { if (out_name == in_name) {
if (op_info->has_ref_index(out_index)) { if (op_info->has_ref_index(out_index)) {
MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info";
@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) {
return true; return true;
} }
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo>& op_info) { bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
for (const auto& exist_op_info : op_info_) { for (const auto &exist_op_info : op_info_) {
MS_EXCEPTION_IF_NULL(exist_op_info); MS_EXCEPTION_IF_NULL(exist_op_info);
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
exist_op_info->impl_path() != op_info->impl_path()) { exist_op_info->impl_path() != op_info->impl_path()) {

View File

@ -28,23 +28,23 @@ class OpLib {
public: public:
OpLib() = default; OpLib() = default;
virtual ~OpLib() = default; virtual ~OpLib() = default;
bool RegOp(const std::string& json_string, const std::string& impl_path); bool RegOp(const std::string &json_string, const std::string &impl_path);
static std::shared_ptr<OpInfo> FindOp(const std::string& op_name, OpImplyType imply_type); static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
protected: protected:
static std::vector<std::shared_ptr<OpInfo>> op_info_; static std::vector<std::shared_ptr<OpInfo>> op_info_;
private: private:
static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path);
static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo>& op_info); const std::shared_ptr<OpInfo> &op_info);
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
size_t index); size_t index);
static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info); static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info);
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format); const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info); static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info);
static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info); static bool CheckRepetition(const std::shared_ptr<OpInfo> &op_info);
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -19,6 +19,6 @@
namespace mindspore { namespace mindspore {
// cppcheck-suppress unusedFunction // cppcheck-suppress unusedFunction
std::string set_version(const std::string& version) { return version; } std::string set_version(const std::string &version) { return version; }
} // namespace mindspore } // namespace mindspore

View File

@ -42,11 +42,11 @@ struct OpMergedInfo {
}; };
using GenAttrFuncType = using GenAttrFuncType =
std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto*, const PrimitivePtr&)>; std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>;
template <typename T, size_t rep_cnt = 0> template <typename T, size_t rep_cnt = 0>
void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
auto casted_value = dyn_cast<T>(value); auto casted_value = dyn_cast<T>(value);
if (casted_value == nullptr) { if (casted_value == nullptr) {
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy
} }
template <size_t beg_idx = 0> template <size_t beg_idx = 0>
void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
auto tuple_ptr = dyn_cast<ValueTuple>(value); auto tuple_ptr = dyn_cast<ValueTuple>(value);
if (tuple_ptr == nullptr) { if (tuple_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed.";
@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib
attr_proto->set_type(attr_type); attr_proto->set_type(attr_type);
} }
void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType, void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
auto attr_value = GetValue<std::string>(value); auto attr_value = GetValue<std::string>(value);
if (attr_value == "VALID") { if (attr_value == "VALID") {
@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType
class OpAttrInfo { class OpAttrInfo {
public: public:
OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name, OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name,
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr)
: attr_name_(attr_name), : attr_name_(attr_name),
onnx_attr_name_(onnx_attr_name), onnx_attr_name_(onnx_attr_name),
onnx_attr_type_(onnx_attr_type), onnx_attr_type_(onnx_attr_type),
fn_gen_attr_(fn_gen_attr) {} fn_gen_attr_(fn_gen_attr) {}
~OpAttrInfo() {} ~OpAttrInfo() {}
const std::string& attr_name() const { return attr_name_; } const std::string &attr_name() const { return attr_name_; }
const std::string& onnx_attr_name() const { return onnx_attr_name_; } const std::string &onnx_attr_name() const { return onnx_attr_name_; }
onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; }
GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; }
@ -134,27 +134,27 @@ class OpAttrInfo {
class OpNameInfo { class OpNameInfo {
public: public:
OpNameInfo& set_op_type(const std::string& op_type) { OpNameInfo &set_op_type(const std::string &op_type) {
op_type_ = op_type; op_type_ = op_type;
return *this; return *this;
} }
const std::string& op_type() const { return op_type_; } const std::string &op_type() const { return op_type_; }
OpNameInfo& set_onnx_type(const std::string& onnx_type) { OpNameInfo &set_onnx_type(const std::string &onnx_type) {
onnx_type_ = onnx_type; onnx_type_ = onnx_type;
return *this; return *this;
} }
const std::string& onnx_type() const { return onnx_type_; } const std::string &onnx_type() const { return onnx_type_; }
OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name, OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name,
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) { onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) {
op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr));
return *this; return *this;
} }
const std::vector<OpAttrInfo>& op_attrs() const { return op_attrs_; } const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; }
private: private:
std::string op_type_; // operator type of MindSpore std::string op_type_; // operator type of MindSpore
@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE(
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>) .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>)
.Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING,
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto, [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
const PrimitivePtr& prim) { const PrimitivePtr &prim) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
auto attr_value = GetValue<std::string>(value); auto attr_value = GetValue<std::string>(value);
if (attr_value == "valid") { if (attr_value == "valid") {
@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax,
SetAttrValueToProto<Int32Imm>) SetAttrValueToProto<Int32Imm>)
.Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT,
[](ValuePtr, onnx::AttributeProto_AttributeType, [](ValuePtr, onnx::AttributeProto_AttributeType,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
attr_proto->set_i(0); attr_proto->set_i(0);
})) }))
@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
void RegisterOpConverters(const std::function<void(OpNameInfo&&)>& fn) { void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)());
fn(OP_CONVERT_FUNCTION_NAME(Mul)()); fn(OP_CONVERT_FUNCTION_NAME(Mul)());
@ -265,16 +265,16 @@ class OpConvertRegistry {
public: public:
~OpConvertRegistry() { Clear(); } ~OpConvertRegistry() { Clear(); }
static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }
static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); }
static OpConvertRegistry& GetSingleton() { static OpConvertRegistry &GetSingleton() {
static OpConvertRegistry registry = OpConvertRegistry(); static OpConvertRegistry registry = OpConvertRegistry();
return registry; return registry;
} }
static const std::unordered_map<std::string, OpNameInfo>& GetOpConvertMap() { return GetSingleton().op_map_; } static const std::unordered_map<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; }
void Clear() noexcept { op_map_.clear(); } void Clear() noexcept { op_map_.clear(); }
@ -289,59 +289,59 @@ class OnnxExporter {
OnnxExporter() {} OnnxExporter() {}
~OnnxExporter() {} ~OnnxExporter() {}
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
private: private:
void InitModelInfo(); void InitModelInfo();
void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto);
void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto);
size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs, const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false); void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false);
void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto); void SetTensorProtoInfo(const ParameterPtr &param, onnx::TensorProto *tensor_proto);
void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes, void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr); std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr);
void ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node, void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node, void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* graph_proto); onnx::GraphProto *graph_proto);
std::string GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* const graph_proto); onnx::GraphProto *const graph_proto);
void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto); void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto);
void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto); void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto);
size_t AllocateNodeIndex() { return ++onnx_node_index_; } size_t AllocateNodeIndex() { return ++onnx_node_index_; }
void ResetNodeIndex() { onnx_node_index_ = 0; } void ResetNodeIndex() { onnx_node_index_ = 0; }
static int GetInt32Value(const AnfNodePtr& node) { static int GetInt32Value(const AnfNodePtr &node) {
auto value_node_ptr = dyn_cast<ValueNode>(node); auto value_node_ptr = dyn_cast<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node_ptr); MS_EXCEPTION_IF_NULL(value_node_ptr);
return GetValue<int>(value_node_ptr->value()); return GetValue<int>(value_node_ptr->value());
@ -352,7 +352,7 @@ class OnnxExporter {
size_t onnx_node_index_ = 0; size_t onnx_node_index_ = 0;
}; };
std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return ""; return "";
} }
@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) {
OpConvertRegistry::GetSingleton().Clear(); OpConvertRegistry::GetSingleton().Clear();
OpConvertRegistry::RegisterAllOpConverters(); OpConvertRegistry::RegisterAllOpConverters();
InitModelInfo(); InitModelInfo();
onnx::GraphProto* graph_proto = model_.mutable_graph(); onnx::GraphProto *graph_proto = model_.mutable_graph();
ExportFuncGraph(func_graph, graph_proto); ExportFuncGraph(func_graph, graph_proto);
return model_.SerializeAsString(); return model_.SerializeAsString();
} }
@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() {
model_.set_ir_version(onnx::IR_VERSION_2019_1_22); model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
model_.set_producer_name("MindSpore"); model_.set_producer_name("MindSpore");
model_.set_producer_version("1.0"); model_.set_producer_version("1.0");
onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import(); onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import();
opset_proto->set_version(9); opset_proto->set_version(9);
} }
void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
std::map<AnfNodePtr, size_t> node_map; std::map<AnfNodePtr, size_t> node_map;
onnx_node_index_ = func_graph->parameters().size(); onnx_node_index_ = func_graph->parameters().size();
@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr
ExportNodes(func_graph, &node_map, graph_proto); ExportNodes(func_graph, &node_map, graph_proto);
} }
void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
for (auto& param : func_graph->parameters()) { for (auto &param : func_graph->parameters()) {
const ParameterPtr param_ptr = dyn_cast<Parameter>(param); const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
if (param_ptr == nullptr) { if (param_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
} }
onnx::ValueInfoProto* input_proto = graph_proto->add_input(); onnx::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_ptr->ToString()); input_proto->set_name(param_ptr->ToString());
SetValueInfoType(param_ptr, input_proto); SetValueInfoType(param_ptr, input_proto);
@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP
continue; continue;
} }
// parameter with default value is an ONNX initializer // parameter with default value is an ONNX initializer
onnx::TensorProto* initializer_proto = graph_proto->add_initializer(); onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_ptr->ToString()); initializer_proto->set_name(param_ptr->ToString());
SetTensorProtoInfo(param_ptr, initializer_proto); SetTensorProtoInfo(param_ptr, initializer_proto);
// set value for initializer // set value for initializer
@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) {
return iter->second; return iter->second;
} }
void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) { void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) {
auto dtype = node->Type(); auto dtype = node->Type();
auto shape = node->Shape(); auto shape = node->Shape();
onnx::TypeProto* type_proto = value_proto->mutable_type(); onnx::TypeProto *type_proto = value_proto->mutable_type();
if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) { if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = dyn_cast<TensorType>(dtype); auto tensor = dyn_cast<TensorType>(dtype);
auto elem_type = tensor->element(); auto elem_type = tensor->element();
const auto& dims = dyn_cast<abstract::Shape>(shape)->shape(); const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
// output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64
auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id());
type_proto->mutable_tensor_type()->set_elem_type(type); type_proto->mutable_tensor_type()->set_elem_type(type);
for (const auto& dim : dims) { for (const auto &dim : dims) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
} }
} }
} }
void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) { void OnnxExporter::SetTensorProtoInfo(const ParameterPtr &param, onnx::TensorProto *const tensor_proto) {
auto dtype = param->Type(); auto dtype = param->Type();
auto shape = param->Shape(); auto shape = param->Shape();
if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) { if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro
auto tensor = dyn_cast<TensorType>(dtype); auto tensor = dyn_cast<TensorType>(dtype);
auto elem_type = tensor->element(); auto elem_type = tensor->element();
const auto& dims = dyn_cast<abstract::Shape>(shape)->shape(); const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id()));
for (const auto& dim : dims) { for (const auto &dim : dims) {
tensor_proto->add_dims(dim); tensor_proto->add_dims(dim);
} }
} }
void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes, void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr) { std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
std::unordered_map<AnfNodePtr, OpMergedInfo>& op_merged_infos = *op_merged_infos_ptr; std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr;
for (auto& node : nodes) { for (auto &node : nodes) {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
} }
@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto
// if the key `input` does not exist, just create a new one // if the key `input` does not exist, just create a new one
op_merged_infos[cnode].referred_count += 1; op_merged_infos[cnode].referred_count += 1;
} }
for (auto& input : cnode->inputs()) { for (auto &input : cnode->inputs()) {
if (!input->isa<CNode>()) { if (!input->isa<CNode>()) {
continue; continue;
} }
@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto
* | +-- Parameter * | +-- Parameter
* | `-- ValueNode * | `-- ValueNode
*/ */
void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* const graph_proto) { onnx::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos; std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
MatchAndMark(func_graph, nodes, &op_merged_infos); MatchAndMark(func_graph, nodes, &op_merged_infos);
for (const AnfNodePtr& node : nodes) { for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
} }
@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodeP
} }
} }
void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_shape = node->input(2); auto input_shape = node->input(2);
std::string name_shape; std::string name_shape;
if (input_shape->isa<ValueNode>()) { if (input_shape->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex(); auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[input_shape] = const_node_idx; (*node_map_ptr)[input_shape] = const_node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
name_shape = std::to_string(const_node_idx); name_shape = std::to_string(const_node_idx);
node_proto->add_output(name_shape); node_proto->add_output(name_shape);
node_proto->set_op_type("Constant"); node_proto->set_op_type("Constant");
onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value"); attr_proto->set_name("value");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C
auto node_idx = AllocateNodeIndex(); auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx; (*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimReshape->name()); node_proto->set_op_type(prim::kPrimReshape->name());
node_proto->add_output(std::to_string(node_idx)); node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(name_x); node_proto->add_input(name_x);
node_proto->add_input(name_shape); node_proto->add_input(name_shape);
} }
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* const graph_proto) { onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_axis = node->input(2); auto input_axis = node->input(2);
auto node_idx = AllocateNodeIndex(); auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx; (*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimReduceMean->name()); node_proto->set_op_type(prim::kPrimReduceMean->name());
node_proto->add_output(std::to_string(node_idx)); node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data); node_proto->add_input(input_data);
if (input_axis->isa<ValueNode>()) { if (input_axis->isa<ValueNode>()) {
onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("axes"); attr_proto->set_name("axes");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
auto axis_value = dyn_cast<ValueNode>(input_axis)->value(); auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons
} }
} }
void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_type = node->input(2); auto input_type = node->input(2);
auto node_idx = AllocateNodeIndex(); auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx; (*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimCast->name()); node_proto->set_op_type(prim::kPrimCast->name());
node_proto->add_output(std::to_string(node_idx)); node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data); node_proto->add_input(input_data);
if (input_type->isa<ValueNode>()) { if (input_type->isa<ValueNode>()) {
onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("to"); attr_proto->set_name("to");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
auto type_value = dyn_cast<ValueNode>(input_type)->value(); auto type_value = dyn_cast<ValueNode>(input_type)->value();
@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod
} }
} }
void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo
// format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2]
if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) {
auto node_idx = AllocateNodeIndex(); auto node_idx = AllocateNodeIndex();
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Unsqueeze"); node_proto->set_op_type("Unsqueeze");
node_proto->add_output(std::to_string(node_idx)); node_proto->add_output(std::to_string(node_idx));
onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
attr_proto->set_name("axes"); attr_proto->set_name("axes");
attr_proto->add_ints(1); attr_proto->add_ints(1);
@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo
auto node_idx = AllocateNodeIndex(); auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx; (*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("PRelu"); node_proto->set_op_type("PRelu");
node_proto->add_output(std::to_string(node_idx)); node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_x); node_proto->add_input(input_x);
node_proto->add_input(input_slope); node_proto->add_input(input_slope);
} }
void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
if (node->IsApply(prim::kPrimReshape)) { if (node->IsApply(prim::kPrimReshape)) {
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
} }
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map<AnfNodePtr, size_t>* node_map_ptr, size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr,
const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs, const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
onnx::GraphProto* const graph_proto) { onnx::GraphProto *const graph_proto) {
auto op_map = OpConvertRegistry::GetOpConvertMap(); auto op_map = OpConvertRegistry::GetOpConvertMap();
auto op_iter = op_map.find(prim->name()); auto op_iter = op_map.find(prim->name());
if (op_iter == op_map.end()) { if (op_iter == op_map.end()) {
MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map";
} }
const OpNameInfo& op_convert_info = op_iter->second; const OpNameInfo &op_convert_info = op_iter->second;
auto node_idx = AllocateNodeIndex(); auto node_idx = AllocateNodeIndex();
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->add_output(std::to_string(node_idx)); node_proto->add_output(std::to_string(node_idx));
node_proto->set_op_type(op_convert_info.onnx_type()); node_proto->set_op_type(op_convert_info.onnx_type());
// Set inputs // Set inputs
for (const auto& input : inputs) { for (const auto &input : inputs) {
auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
node_proto->add_input(input_name); node_proto->add_input(input_name);
} }
// Set node attribute // Set node attribute
for (const OpAttrInfo& attr : op_convert_info.op_attrs()) { for (const OpAttrInfo &attr : op_convert_info.op_attrs()) {
const std::string& attr_name = attr.attr_name(); const std::string &attr_name = attr.attr_name();
ValuePtr attr_value = nullptr; ValuePtr attr_value = nullptr;
if (!attr_name.empty()) { if (!attr_name.empty()) {
attr_value = prim->GetAttr(attr_name); attr_value = prim->GetAttr(attr_name);
@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name;
} }
} }
onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute(); onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name(attr.onnx_attr_name()); onnx_attr_proto->set_name(attr.onnx_attr_name());
attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim);
} }
return node_idx; return node_idx;
} }
void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto conv_node = dyn_cast<CNode>(node->input(1)); auto conv_node = dyn_cast<CNode>(node->input(1));
auto input_x = conv_node->input(1); // conv input x auto input_x = conv_node->input(1); // conv input x
auto input_w = conv_node->input(2); // conv weight(filter) auto input_w = conv_node->input(2); // conv weight(filter)
@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto);
} }
void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto matmul_node = dyn_cast<CNode>(node->input(1)); auto matmul_node = dyn_cast<CNode>(node->input(1));
auto input_x = matmul_node->input(1); // matmul input x auto input_x = matmul_node->input(1); // matmul input x
auto input_y = matmul_node->input(2); // matmul input y auto input_y = matmul_node->input(2); // matmul input y
@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
} }
void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* const graph_proto) { onnx::GraphProto *const graph_proto) {
auto batch_norm_node = dyn_cast<CNode>(node->input(1)); auto batch_norm_node = dyn_cast<CNode>(node->input(1));
PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value()); PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value());
@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
} }
void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
if (node->inputs().size() != 2) { if (node->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
} }
AnfNodePtr arg = node->input(1); AnfNodePtr arg = node->input(1);
std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto);
onnx::ValueInfoProto* output_proto = graph_proto->add_output(); onnx::ValueInfoProto *output_proto = graph_proto->add_output();
output_proto->set_name(name); output_proto->set_name(name);
SetValueInfoType(arg, output_proto, false); SetValueInfoType(arg, output_proto, false);
} }
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto* const graph_proto) { onnx::GraphProto *const graph_proto) {
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto iter = node_map_ptr->find(node); auto iter = node_map_ptr->find(node);
if (iter == node_map_ptr->end()) { if (iter == node_map_ptr->end()) {
@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN
(*node_map_ptr)[node] = node_idx; (*node_map_ptr)[node] = node_idx;
std::string node_name = std::to_string(node_idx); std::string node_name = std::to_string(node_idx);
onnx::NodeProto* node_proto = graph_proto->add_node(); onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->add_output(node_name); node_proto->add_output(node_name);
SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto); SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto);
@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN
MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name();
} }
void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) { void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
auto tuple_ptr = dyn_cast<ValueTuple>(value); auto tuple_ptr = dyn_cast<ValueTuple>(value);
MS_EXCEPTION_IF_NULL(tuple_ptr); MS_EXCEPTION_IF_NULL(tuple_ptr);
if (tuple_ptr->size() == 0) { if (tuple_ptr->size() == 0) {
@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto
} }
} }
void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) { void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) {
node_proto->set_op_type("Constant"); node_proto->set_op_type("Constant");
onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value"); attr_proto->set_name("value");
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
} }
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
OnnxExporter exporter; OnnxExporter exporter;
return exporter.GetOnnxProtoString(func_graph); return exporter.GetOnnxProtoString(func_graph);
} }

View File

@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown };
// Whether has a T type data in AnyPtrList. // Whether has a T type data in AnyPtrList.
template <class T> template <class T>
bool HasType(const AnyPtrList& list) { bool HasType(const AnyPtrList &list) {
bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is<T>(); }); bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
return ret; return ret;
} }
DataType InferType(const AnyPtrList& list) { DataType InferType(const AnyPtrList &list) {
if (HasType<double>(list)) { if (HasType<double>(list)) {
return DataType::kDouble; return DataType::kDouble;
} else if (HasType<float>(list)) { } else if (HasType<float>(list)) {
@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) {
} }
#define SCALAR_OP(op_t) \ #define SCALAR_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \ ValuePtr Scalar##op_t(const ValuePtrList &list) { \
do { \ do { \
if (list.size() < 2) { \ if (list.size() < 2) { \
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
@ -223,7 +223,7 @@ SCALAR_OP(Pow)
SCALAR_OP(Floordiv) SCALAR_OP(Floordiv)
#define LOGIC_OP(op_t) \ #define LOGIC_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \ ValuePtr Scalar##op_t(const ValuePtrList &list) { \
if (list.size() < 2) { \ if (list.size() < 2) { \
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
} \ } \
@ -274,7 +274,7 @@ LOGIC_OP(Ne)
LOGIC_OP(Le) LOGIC_OP(Le)
LOGIC_OP(Ge) LOGIC_OP(Ge)
ValuePtr ScalarUAdd(const ValuePtrList& list) { ValuePtr ScalarUAdd(const ValuePtrList &list) {
if (list.size() != 1) { if (list.size() != 1) {
MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
} }
@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) {
return x; return x;
} }
ValuePtr ScalarUSub(const ValuePtrList& list) { ValuePtr ScalarUSub(const ValuePtrList &list) {
if (list.size() != 1) { if (list.size() != 1) {
MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
} }
@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
} }
ValuePtr ScalarLog(const ValuePtrList& list) { ValuePtr ScalarLog(const ValuePtrList &list) {
if (list.empty()) { if (list.empty()) {
MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
} }
@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
} }
ValuePtr BoolNot(const ValuePtrList& list) { ValuePtr BoolNot(const ValuePtrList &list) {
if (list.empty()) { if (list.empty()) {
MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
} }
@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
} }
ValuePtr BoolAnd(const ValuePtrList& list) { ValuePtr BoolAnd(const ValuePtrList &list) {
if (list.size() < 2) { if (list.size() < 2) {
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
} }
@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
} }
ValuePtr BoolOr(const ValuePtrList& list) { ValuePtr BoolOr(const ValuePtrList &list) {
if (list.size() < 2) { if (list.size() < 2) {
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
} }
@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
} }
ValuePtr BoolEq(const ValuePtrList& list) { ValuePtr BoolEq(const ValuePtrList &list) {
if (list.size() < 2) { if (list.size() < 2) {
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
} }

View File

@ -29,29 +29,29 @@ namespace prim {
using Any = mindspore::Any; using Any = mindspore::Any;
using AnyPtrList = std::vector<std::shared_ptr<Any>>; using AnyPtrList = std::vector<std::shared_ptr<Any>>;
using ValuePtrList = std::vector<ValuePtr>; using ValuePtrList = std::vector<ValuePtr>;
using OpsFunction = std::function<Any(const AnyPtrList&)>; using OpsFunction = std::function<Any(const AnyPtrList &)>;
using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr>&)>; using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr> &)>;
ValuePtr ScalarAdd(const ValuePtrList& list); ValuePtr ScalarAdd(const ValuePtrList &list);
ValuePtr ScalarSub(const ValuePtrList& list); ValuePtr ScalarSub(const ValuePtrList &list);
ValuePtr ScalarMul(const ValuePtrList& list); ValuePtr ScalarMul(const ValuePtrList &list);
ValuePtr ScalarDiv(const ValuePtrList& list); ValuePtr ScalarDiv(const ValuePtrList &list);
ValuePtr ScalarMod(const ValuePtrList& list); ValuePtr ScalarMod(const ValuePtrList &list);
ValuePtr ScalarPow(const ValuePtrList& list); ValuePtr ScalarPow(const ValuePtrList &list);
ValuePtr ScalarFloordiv(const ValuePtrList& list); ValuePtr ScalarFloordiv(const ValuePtrList &list);
ValuePtr ScalarUAdd(const ValuePtrList& list); ValuePtr ScalarUAdd(const ValuePtrList &list);
ValuePtr ScalarUSub(const ValuePtrList& list); ValuePtr ScalarUSub(const ValuePtrList &list);
ValuePtr ScalarLog(const ValuePtrList& list); ValuePtr ScalarLog(const ValuePtrList &list);
ValuePtr ScalarEq(const ValuePtrList& list); ValuePtr ScalarEq(const ValuePtrList &list);
ValuePtr ScalarLt(const ValuePtrList& list); ValuePtr ScalarLt(const ValuePtrList &list);
ValuePtr ScalarGt(const ValuePtrList& list); ValuePtr ScalarGt(const ValuePtrList &list);
ValuePtr ScalarNe(const ValuePtrList& list); ValuePtr ScalarNe(const ValuePtrList &list);
ValuePtr ScalarLe(const ValuePtrList& list); ValuePtr ScalarLe(const ValuePtrList &list);
ValuePtr ScalarGe(const ValuePtrList& list); ValuePtr ScalarGe(const ValuePtrList &list);
ValuePtr BoolNot(const ValuePtrList& list); ValuePtr BoolNot(const ValuePtrList &list);
ValuePtr BoolAnd(const ValuePtrList& list); ValuePtr BoolAnd(const ValuePtrList &list);
ValuePtr BoolOr(const ValuePtrList& list); ValuePtr BoolOr(const ValuePtrList &list);
ValuePtr BoolEq(const ValuePtrList& list); ValuePtr BoolEq(const ValuePtrList &list);
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2); std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
// Apply a function of two arguments cumulatively to the items of a sequence, // Apply a function of two arguments cumulatively to the items of a sequence,
// from left to right, so as to reduce the sequence to a single value.For example, // from left to right, so as to reduce the sequence to a single value.For example,
// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
std::shared_ptr<Any> ret; std::shared_ptr<Any> ret;
size_t size = list.size(); size_t size = list.size();
if (size < 2) { if (size < 2) {
@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) {
return ret; return ret;
} }
AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector<AnfNodePtr>& list) { AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
size_t size = list.size(); size_t size = list.size();
if (size < 2) { if (size < 2) {
MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
@ -121,7 +121,7 @@ void HyperMap::Init() {
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
} }
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf) HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
: MetaFuncGraph("hyper_map"), : MetaFuncGraph("hyper_map"),
fn_leaf_(fn_leaf), fn_leaf_(fn_leaf),
broadcast_(false), broadcast_(false),
@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf)
Init(); Init();
} }
HyperMap::HyperMap(const HyperMap& h) HyperMap::HyperMap(const HyperMap &h)
: MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
Init(); Init();
} }
AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList& arg_map) { const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
if (fn_arg != nullptr) { if (fn_arg != nullptr) {
@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf
} }
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
[](const std::pair<AnfNodePtr, Any>& item) { return item.first; }); [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
return func_graph->NewCNode(inputs); return func_graph->NewCNode(inputs);
} }
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size(); std::size_t size = type->elements().size();
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) { bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::static_pointer_cast<List>(item.second); auto lhs = std::static_pointer_cast<List>(item.second);
MS_EXCEPTION_IF_NULL(lhs); MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size; return lhs->elements().size() != size;
@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph
(void)std::transform( (void)std::transform(
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
[&func_graph, i](const std::pair<AnfNodePtr, Any>& item) { [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
}); });
@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph
return func_graph->NewCNode(inputs); return func_graph->NewCNode(inputs);
} }
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size(); std::size_t size = type->elements().size();
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) { bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::static_pointer_cast<Tuple>(item.second); auto lhs = std::static_pointer_cast<Tuple>(item.second);
MS_EXCEPTION_IF_NULL(lhs); MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size; return lhs->elements().size() != size;
@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGrap
return func_graph->NewCNode(inputs); return func_graph->NewCNode(inputs);
} }
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGrap
return func_graph->NewCNode(inputs); return func_graph->NewCNode(inputs);
} }
AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
bool found = false; bool found = false;
TypeId id = kObjectTypeEnd; TypeId id = kObjectTypeEnd;
std::pair<AnfNodePtr, TypePtr> pair; std::pair<AnfNodePtr, TypePtr> pair;
for (auto& item : arg_map) { for (auto &item : arg_map) {
pair = item; pair = item;
id = item.second->type_id(); id = item.second->type_id();
if (nonleaf_.count(id)) { if (nonleaf_.count(id)) {
@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
if (found) { if (found) {
// In a nonleaf situation, all arguments must have the same generic. // In a nonleaf situation, all arguments must have the same generic.
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr>& item) { bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
if (item.first != pair.first) { if (item.first != pair.first) {
return item.second->type_id() != pair.second->type_id(); return item.second->type_id() != pair.second->type_id();
} }
@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n"; << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
int idx = 0; int idx = 0;
for (auto& item : arg_map) { for (auto &item : arg_map) {
oss << ++idx << ": " << item.second->ToString() << "\n"; oss << ++idx << ": " << item.second->ToString() << "\n";
} }
MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
} }
} }
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) { ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
TypePtr type_tensor = std::make_shared<TensorType>(); TypePtr type_tensor = std::make_shared<TensorType>();
bool flag = std::any_of( bool flag = std::any_of(
args_spec_list.begin(), args_spec_list.end(), args_spec_list.begin(), args_spec_list.end(),
[type_tensor](const std::pair<AnfNodePtr, TypePtr>& item) { return IsSubType(item.second, type_tensor); }); [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
if (flag && broadcast_) { if (flag && broadcast_) {
ArgsPairList ret; ArgsPairList ret;
for (auto& item : args_spec_list) { for (auto &item : args_spec_list) {
if (!IsSubType(item.second, type_tensor)) { if (!IsSubType(item.second, type_tensor)) {
TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second); TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
ret.push_back( ret.push_back(
@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL
return args_spec_list; return args_spec_list;
} }
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->debug_info()->set_name("hyper_map"); ptrGraph->debug_info()->set_name("hyper_map");
@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) {
return ptrGraph; return ptrGraph;
} }
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const { abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
if (fn_leaf_ == nullptr) { if (fn_leaf_ == nullptr) {
MS_EXCEPTION_IF_NULL(args_spec_list[0]); MS_EXCEPTION_IF_NULL(args_spec_list[0]);
// Assert that hypermap's function param does not contain free variables // Assert that hypermap's function param does not contain free variables
@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList&
AbstractBasePtrList broadened; AbstractBasePtrList broadened;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
[](const AbstractBasePtr& arg) -> AbstractBasePtr { [](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
return arg->Broaden(); return arg->Broaden();
}); });
return broadened; return broadened;
} }
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
(void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_") (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf")) .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
.def(py::init<>()); .def(py::init<>());
})); }));
FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) { FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
MS_EXCEPTION_IF_NULL(a_tuple); MS_EXCEPTION_IF_NULL(a_tuple);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu
return ret; return ret;
} }
FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) { FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
MS_EXCEPTION_IF_NULL(a_list); MS_EXCEPTION_IF_NULL(a_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list
return ret; return ret;
} }
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() != 1) { if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
} }
@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list)
} }
REGISTER_PYBIND_DEFINE( REGISTER_PYBIND_DEFINE(
Tail_, ([](const py::module* m) { Tail_, ([](const py::module *m) {
(void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string&>()); (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
})); }));
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
int tuple_size = SizeToInt(args_spec_list.size()); int tuple_size = SizeToInt(args_spec_list.size());
std::ostringstream ss; std::ostringstream ss;
@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg
return fg; return fg;
} }
GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param) GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
if (get_by_list) { if (get_by_list) {
signatures_ = signatures_ =
@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_
} }
} }
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr>& params_list, bool applyJ) { const std::vector<AnfNodePtr> &params_list, bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights,
return ret; return ret;
} }
void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem) { ValueNodePtr opsTupleItem) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An
} }
// Generate the graph. // Generate the graph.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() < 1) { if (args_spec_list.size() < 1) {
MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
<< args_spec_list.size() << "."; << args_spec_list.size() << ".";
@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp
return dfBuilder; return dfBuilder;
} }
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
(void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>( (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
*m, "GradOperation_") *m, "GradOperation_")
.def(py::init<std::string&>(), py::arg("fn")) .def(py::init<std::string &>(), py::arg("fn"))
.def(py::init<std::string&, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
py::arg("get_by_list"), py::arg("sens_param")); py::arg("get_by_list"), py::arg("sens_param"));
})); }));
MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) { MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
fn_cache_.clear(); fn_cache_.clear();
signatures_ = std::vector<Signature>({// def multitype(*args:ref): signatures_ = std::vector<Signature>({// def multitype(*args:ref):
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
} }
void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) { void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
auto fn = fn_cache_.find(types); auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) { if (fn != fn_cache_.end()) {
@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn)
fn_cache_[types] = s_fn; fn_cache_[types] = s_fn;
} }
void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) { void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
auto fn = fn_cache_.find(types); auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) { if (fn != fn_cache_.end()) {
@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function&
fn_cache_py_[types] = py_fn; fn_cache_py_[types] = py_fn;
} }
void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, const py::function& py_fn) { void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
TypePtrList types; TypePtrList types;
for (auto& type_name : types_name) { for (auto &type_name : types_name) {
auto type_ptr = StringToType(type_name); auto type_ptr = StringToType(type_name);
if (type_ptr == nullptr) { if (type_ptr == nullptr) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; MS_LOG(EXCEPTION) << "" << type_name << " convert from string error ";
@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, co
Register(types, py_fn); Register(types, py_fn);
} }
void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) { void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
std::vector<std::string> types_name; std::vector<std::string> types_name;
for (size_t it = 0; it < tuple.size(); ++it) { for (size_t it = 0; it < tuple.size(); ++it) {
py::object name_py = tuple[it]; py::object name_py = tuple[it];
@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function&
} }
Register(types_name, py_fn); Register(types_name, py_fn);
} }
static TypePtr UnwrapRef(const TypePtr& type) { static TypePtr UnwrapRef(const TypePtr &type) {
if (type->isa<RefType>()) { if (type->isa<RefType>()) {
return type->cast<RefTypePtr>()->subtype(); return type->cast<RefTypePtr>()->subtype();
} }
return type; return type;
} }
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
bool find_fn = false; bool find_fn = false;
py::function py_fn; py::function py_fn;
for (auto& item : fn_cache_py_) { for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first; TypePtrList sign = item.first;
if (sign.size() != types.size()) { if (sign.size() != types.size()) {
continue; continue;
@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
<< "`, corresponding location info:\n"; << "`, corresponding location info:\n";
int idx = 0; int idx = 0;
for (auto& item : fn_cache_py_) { for (auto &item : fn_cache_py_) {
FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
<< oss.str(); << oss.str();
} }
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
(void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>( (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
*m, "MultitypeFuncGraph_") *m, "MultitypeFuncGraph_")
.def(py::init<std::string&>()) .def(py::init<std::string &>())
.def("register_fn", &MultitypeFuncGraph::PyRegister); .def("register_fn", &MultitypeFuncGraph::PyRegister);
})); }));
// Generate the ListMap func graph. // Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
size_t args_num = args_spec_list.size(); size_t args_num = args_spec_list.size();
// args: fn, list1, list2, ... // args: fn, list1, list2, ...
if (args_num < 2) { if (args_num < 2) {
@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis
return fg_ptr; return fg_ptr;
} }
void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgnext_ptr, void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
const FuncGraphPtr& fg_ptr) { const FuncGraphPtr &fg_ptr) {
MS_EXCEPTION_IF_NULL(fg_ptr); MS_EXCEPTION_IF_NULL(fg_ptr);
AnfNodePtr fn = fg_ptr->add_parameter(); AnfNodePtr fn = fg_ptr->add_parameter();
@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr&
fgtrue_ptr->set_output(output_cnode); fgtrue_ptr->set_output(output_cnode);
} }
void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgcond_ptr, void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
const FuncGraphPtr& fg_ptr) { const FuncGraphPtr &fg_ptr) {
MS_EXCEPTION_IF_NULL(fg_ptr); MS_EXCEPTION_IF_NULL(fg_ptr);
AnfNodePtr fn = fg_ptr->add_parameter(); AnfNodePtr fn = fg_ptr->add_parameter();
@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr&
fg_ptr->set_output(output_cnode); fg_ptr->set_output(output_cnode);
} }
FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// args: tuple1, tuple2 // args: tuple1, tuple2
abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
AbstractBasePtr abs_a = args_spec_list[0]; AbstractBasePtr abs_a = args_spec_list[0];
@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li
return ret; return ret;
} }
int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) { int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
MS_EXCEPTION_IF_NULL(scalar); MS_EXCEPTION_IF_NULL(scalar);
return GetValue<int>(scalar->BuildValue()); return GetValue<int>(scalar->BuildValue());
} }
@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) {
return index; return index;
} }
int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) { int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
MS_EXCEPTION_IF_NULL(member); MS_EXCEPTION_IF_NULL(member);
if (member->isa<AbstractScalar>()) { if (member->isa<AbstractScalar>()) {
@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std
<< member->ToString(); << member->ToString();
} }
void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index, void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
int* stop_index, int* step_value) { int *stop_index, int *step_value) {
MS_EXCEPTION_IF_NULL(tuple); MS_EXCEPTION_IF_NULL(tuple);
MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(start_index); MS_EXCEPTION_IF_NULL(start_index);
@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl
} }
} }
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tuple // slice a tuple
// args: tuple, start index, end index, step // args: tuple, start index, end index, step
const std::string op_name("TupleSlice"); const std::string op_name("TupleSlice");
@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
return ret; return ret;
} }
int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) { int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
unsigned int number_dec = 0; unsigned int number_dec = 0;
for (size_t index = 0; index < number_bin.size(); index++) { for (size_t index = 0; index < number_bin.size(); index++) {
number_dec |= number_bin[index] << index; number_dec |= number_bin[index] << index;
@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) {
return static_cast<int>(number_dec); return static_cast<int>(number_dec);
} }
void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vector<int>* end, void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
std::vector<int>* strides, int length) { std::vector<int> *strides, int length) {
MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(end);
@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vec
strides->push_back(step_value); strides->push_back(step_value);
} }
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector<int>& shape, int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) { std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(slice_tuple); MS_EXCEPTION_IF_NULL(slice_tuple);
MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(end);
@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple,
return ConvertBinaryToDecimal(shrink); return ConvertBinaryToDecimal(shrink);
} }
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector<int>& shape, int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) { std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides); MS_EXCEPTION_IF_NULL(strides);
@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const
return 0; return 0;
} }
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector<int>& shape, int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
std::vector<int>* begin, std::vector<int>* end, std::vector<int> *begin, std::vector<int> *end,
std::vector<int>* strides) { std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides); MS_EXCEPTION_IF_NULL(strides);
@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co
return 1; return 1;
} }
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor // slice a tensor
// args: tensor, slice or slice tuple // args: tensor, slice or slice tuple
const std::string op_name = std::string("TensorSlice"); const std::string op_name = std::string("TensorSlice");
@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else { } else {
std::ostringstream args_info; std::ostringstream args_info;
for (const auto& arg : args_spec_list) { for (const auto &arg : args_spec_list) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
args_info << arg->ToString() << "\n"; args_info << arg->ToString() << "\n";
} }
@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
return ret_graph; return ret_graph;
} }
REGISTER_PYBIND_DEFINE( REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
TupleAdd_, ([](const py::module* m) { (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>()); .def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) {
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
.def(py::init<std::string&>());
})); }));
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
.def(py::init<std::string&>()); .def(py::init<std::string &>());
})); }));
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -47,20 +47,20 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
class MultitypeFuncGraph : public MetaFuncGraph { class MultitypeFuncGraph : public MetaFuncGraph {
public: public:
explicit MultitypeFuncGraph(const std::string& name); explicit MultitypeFuncGraph(const std::string &name);
~MultitypeFuncGraph() override = default; ~MultitypeFuncGraph() override = default;
MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
using specialize_fn = FuncGraph* (*)(TypePtrList); using specialize_fn = FuncGraph *(*)(TypePtrList);
// Register a method which specialize based on types vectors; // Register a method which specialize based on types vectors;
virtual void Register(const TypePtrList& types, specialize_fn s_fn); virtual void Register(const TypePtrList &types, specialize_fn s_fn);
virtual void Register(const TypePtrList& types, const py::function& py_fn); virtual void Register(const TypePtrList &types, const py::function &py_fn);
virtual void Register(const std::vector<std::string>& types_name, const py::function& py_fn); virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn); virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override; FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual>& GetPyFunctions() const { const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const {
return fn_cache_py_; return fn_cache_py_;
} }
@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
class HyperMap : public MetaFuncGraph { class HyperMap : public MetaFuncGraph {
public: public:
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr); explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
HyperMap(const HyperMap& h); HyperMap(const HyperMap &h);
void Init(); void Init();
HyperMap& operator=(const HyperMap& h) { HyperMap &operator=(const HyperMap &h) {
if (this != &h) { if (this != &h) {
fn_leaf_ = h.fn_leaf_; fn_leaf_ = h.fn_leaf_;
broadcast_ = h.broadcast_; broadcast_ = h.broadcast_;
@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph {
~HyperMap() override = default; ~HyperMap() override = default;
MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override; abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override; FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
private: private:
AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList& arg_map); const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList& arg_map); const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList& arg_map); const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList& arg_map); const ArgsPairList &arg_map);
AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map); AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list); ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
MultitypeFuncGraphPtr fn_leaf_; MultitypeFuncGraphPtr fn_leaf_;
bool broadcast_; bool broadcast_;
@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr<HyperMap>;
class HyperMapPy : public HyperMap { class HyperMapPy : public HyperMap {
public: public:
explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr) : HyperMap(fn_leaf) {} explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
~HyperMapPy() override = default; ~HyperMapPy() override = default;
MS_DECLARE_PARENT(HyperMapPy, HyperMap) MS_DECLARE_PARENT(HyperMapPy, HyperMap)
}; };
@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap;
class Tail : public MetaFuncGraph { class Tail : public MetaFuncGraph {
public: public:
explicit Tail(const std::string& name) : MetaFuncGraph(name) {} explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
~Tail() override = default; ~Tail() override = default;
MS_DECLARE_PARENT(Tail, MetaFuncGraph) MS_DECLARE_PARENT(Tail, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple); FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list); FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);
friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
}; };
using TailPtr = std::shared_ptr<Tail>; using TailPtr = std::shared_ptr<Tail>;
class MakeTupleGradient : public MetaFuncGraph { class MakeTupleGradient : public MetaFuncGraph {
public: public:
explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {} explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
~MakeTupleGradient() override = default; ~MakeTupleGradient() override = default;
MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
}; };
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
class GradOperation : public MetaFuncGraph { class GradOperation : public MetaFuncGraph {
public: public:
explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false, explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
bool sens_param = false); bool sens_param = false);
~GradOperation() override = default; ~GradOperation() override = default;
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams, FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
bool applyJ = false); bool applyJ = false);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
bool sens_param() const { return sens_param_; } bool sens_param() const { return sens_param_; }
bool get_all_; bool get_all_;
bool get_by_list_; bool get_by_list_;
bool sens_param_; bool sens_param_;
private: private:
void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem); ValueNodePtr opsTupleItem);
}; };
using GradOperationPtr = std::shared_ptr<GradOperation>; using GradOperationPtr = std::shared_ptr<GradOperation>;
class ListMap { class ListMap {
public: public:
explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); } explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
~ListMap() = default; ~ListMap() = default;
void MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr); void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
void MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr); void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list); FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
private: private:
std::string name_; std::string name_;
@ -181,31 +181,31 @@ class ListMap {
class TupleAdd : public MetaFuncGraph { class TupleAdd : public MetaFuncGraph {
public: public:
explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {} explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
~TupleAdd() override = default; ~TupleAdd() override = default;
MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
}; };
using TupleAddPtr = std::shared_ptr<TupleAdd>; using TupleAddPtr = std::shared_ptr<TupleAdd>;
class TupleSlice : public MetaFuncGraph { class TupleSlice : public MetaFuncGraph {
public: public:
explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {} explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
~TupleSlice() override = default; ~TupleSlice() override = default;
MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
}; };
using TupleSlicePtr = std::shared_ptr<TupleSlice>; using TupleSlicePtr = std::shared_ptr<TupleSlice>;
class TensorSlice : public MetaFuncGraph { class TensorSlice : public MetaFuncGraph {
public: public:
explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {} explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
~TensorSlice() override = default; ~TensorSlice() override = default;
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
}; };
using TensorSlicePtr = std::shared_ptr<TensorSlice>; using TensorSlicePtr = std::shared_ptr<TensorSlice>;

View File

@ -34,7 +34,7 @@ namespace prim {
namespace { namespace {
using PatternListType = std::initializer_list<BaseRef>; using PatternListType = std::initializer_list<BaseRef>;
const std::vector<Signature>& GetSignature(const ValuePtr& function) { const std::vector<Signature> &GetSignature(const ValuePtr &function) {
static const auto empty = std::vector<Signature>(); static const auto empty = std::vector<Signature>();
if (function->isa<Primitive>()) { if (function->isa<Primitive>()) {
return function->cast<PrimitivePtr>()->signatures(); return function->cast<PrimitivePtr>()->signatures();
@ -44,8 +44,8 @@ const std::vector<Signature>& GetSignature(const ValuePtr& function) {
return empty; return empty;
} }
void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list, void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
const std::vector<Signature>& signature, bool has_var, std::vector<AnfNodePtr>* op_inputs) { const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *op_inputs) {
std::size_t sig_size = signature.size(); std::size_t sig_size = signature.size();
auto positional_size = sig_size; auto positional_size = sig_size;
if (has_var) { if (has_var) {
@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg
} }
// Get the largest type of index in the same SignatureEnumDType of arguments. // Get the largest type of index in the same SignatureEnumDType of arguments.
std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType>& dtypes, std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList& args_spec_list) { const abstract::AbstractBasePtrList &args_spec_list) {
// record index for signature.dtypes of the same type // record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
@ -89,7 +89,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
continue; continue;
} }
for (const auto& index : indexs) { for (const auto &index : indexs) {
AbstractBasePtr arg_value = args_spec_list[index]; AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) { if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
@ -104,7 +104,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
return dst_type; return dst_type;
} }
AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const FuncGraphPtr& graph) { AnfNodePtr DoCast(const AnfNodePtr &param, const AnfNodePtr &source_param, const FuncGraphPtr &graph) {
// op and module import path // op and module import path
auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional"); auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional");
MS_EXCEPTION_IF_NULL(prim_dtype); MS_EXCEPTION_IF_NULL(prim_dtype);
@ -116,11 +116,11 @@ AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const
return NewCNode({cast_node, param, dtype_node}, graph); return NewCNode({cast_node, param, dtype_node}, graph);
} }
void DoAutoCast(const std::vector<Signature>& signature, const abstract::AbstractBasePtrList& args_spec_list, void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
const FuncGraphPtr& graph, std::vector<AnfNodePtr>* op_inputs) { const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs) {
std::vector<SignatureEnumDType> dtypes; std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature& sig) { return sig.dtype; }); [](const Signature &sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) { if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return; return;
@ -143,10 +143,10 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
} }
} }
AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList& args_spec_list, const std::vector<AnfNodePtr>& params_list) { const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
// args: original inputs // args: original inputs
auto& signature = GetSignature(function); auto &signature = GetSignature(function);
std::size_t sig_size = signature.size(); std::size_t sig_size = signature.size();
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
if (sig_size > 0) { if (sig_size > 0) {
@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func
} }
} // namespace } // namespace
AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) { const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) {
auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
return new_cnode; return new_cnode;
} }
FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
for (size_t i = 0; i < args_spec_list.size(); ++i) { for (size_t i = 0; i < args_spec_list.size(); ++i) {

View File

@ -37,17 +37,17 @@ namespace mindspore {
namespace prim { namespace prim {
class DoSignatureMetaFuncGraph : public MetaFuncGraph { class DoSignatureMetaFuncGraph : public MetaFuncGraph {
public: public:
explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function) explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function)
: MetaFuncGraph("S-" + name), function_(function) {} : MetaFuncGraph("S-" + name), function_(function) {}
~DoSignatureMetaFuncGraph() override = default; ~DoSignatureMetaFuncGraph() override = default;
MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override;
const ValuePtr function() const { return function_; } const ValuePtr function() const { return function_; }
friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) { friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) {
return &lhs == &rhs; return &lhs == &rhs;
} }
@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph {
}; };
using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>; using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>;
AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs); const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs);
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
// namespace to support composite operators definition // namespace to support composite operators definition
namespace prim { namespace prim {
FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) { FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
abstract::CheckArgsSize("ListAppend", args_list, 2); abstract::CheckArgsSize("ListAppend", args_list, 2);
AbstractBasePtr arg0 = args_list[0]; AbstractBasePtr arg0 = args_list[0];
@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList&
return ret; return ret;
} }
REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) {
(void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_") (void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_")
.def(py::init<std::string&>()); .def(py::init<std::string &>());
})); }));
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -28,15 +28,15 @@ namespace mindspore {
namespace prim { namespace prim {
class ListAppend : public MetaFuncGraph { class ListAppend : public MetaFuncGraph {
public: public:
explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {} explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {}
~ListAppend() override = default; ~ListAppend() override = default;
MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) MS_DECLARE_PARENT(ListAppend, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override; FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) { friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) {
os << list_append.name_; os << list_append.name_;
return os; return os;
} }
friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; }
}; };
using ListAppendPtr = std::shared_ptr<ListAppend>; using ListAppendPtr = std::shared_ptr<ListAppend>;
} // namespace prim } // namespace prim

View File

@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr; using mindspore::abstract::AbstractTuplePtr;
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor // slice a tensor
// args: tensor, slice or slice tuple // args: tensor, slice or slice tuple
const std::string op_name = std::string("UnpackCall"); const std::string op_name = std::string("UnpackCall");
@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
AnfNodePtr para_dict = ret_graph->add_parameter(); AnfNodePtr para_dict = ret_graph->add_parameter();
auto dict_elems = arg_dict->elements(); auto dict_elems = arg_dict->elements();
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
[ret_graph, para_dict](const AbstractAttribute& item) { [ret_graph, para_dict](const AbstractAttribute &item) {
auto dict_get_item = ret_graph->NewCNode( auto dict_get_item = ret_graph->NewCNode(
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
return ret_graph->NewCNode( return ret_graph->NewCNode(
@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
return ret_graph; return ret_graph;
} }
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) {
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_") (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
.def(py::init<std::string&>()); .def(py::init<std::string &>());
})); }));
} // namespace prim } // namespace prim

View File

@ -40,11 +40,11 @@ namespace prim {
// and generate positional parameters and key-value pairs for function. // and generate positional parameters and key-value pairs for function.
class UnpackCall : public MetaFuncGraph { class UnpackCall : public MetaFuncGraph {
public: public:
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {}
~UnpackCall() override = default; ~UnpackCall() override = default;
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; }
}; };
using UnpackCallPtr = std::shared_ptr<UnpackCall>; using UnpackCallPtr = std::shared_ptr<UnpackCall>;

View File

@ -36,7 +36,7 @@ namespace prim {
using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuple;
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// zip operation: // zip operation:
// input: tuple arguments // input: tuple arguments
// output: tuple of items of input iterated on every input // output: tuple of items of input iterated on every input
@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
} }
auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool { auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs); MS_EXCEPTION_IF_NULL(abs);
return abs->isa<AbstractTuple>(); return abs->isa<AbstractTuple>();
}); });
@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
} }
auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
[](const AbstractBasePtr& x, const AbstractBasePtr& y) { [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
}); });
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
return ret_graph; return ret_graph;
} }
REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) {
(void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m, (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m,
"ZipOperation_") "ZipOperation_")
.def(py::init<std::string&>()); .def(py::init<std::string &>());
})); }));
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr;
class ZipOperation : public MetaFuncGraph { class ZipOperation : public MetaFuncGraph {
public: public:
explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {} explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {}
~ZipOperation() override = default; ~ZipOperation() override = default;
MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) { friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) {
os << op.name_; os << op.name_;
return os; return os;
} }
friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; }
}; };
using ZipOperationPtr = std::shared_ptr<ZipOperation>; using ZipOperationPtr = std::shared_ptr<ZipOperation>;
} // namespace prim } // namespace prim

View File

@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
ValuePtr node = nullptr; ValuePtr node = nullptr;
bool succ = parse::ConvertData(obj, &node); bool succ = parse::ConvertData(obj, &node);

View File

@ -26,8 +26,8 @@
namespace mindspore { namespace mindspore {
// namespace to support primitive operators // namespace to support primitive operators
namespace prim { namespace prim {
ValuePtr GetPythonOps(const std::string& op_name, ValuePtr GetPythonOps(const std::string &op_name,
const std::string& module_name = "mindspore._extends.parse.standard_method"); const std::string &module_name = "mindspore._extends.parse.standard_method");
// Arithmetic // Arithmetic
extern const PrimitivePtr kPrimScalarAdd; extern const PrimitivePtr kPrimScalarAdd;
@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset;
class DoSignaturePrimitive : public Primitive { class DoSignaturePrimitive : public Primitive {
public: public:
explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function) explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
: Primitive("S-Prim-" + name), function_(function) {} : Primitive("S-Prim-" + name), function_(function) {}
~DoSignaturePrimitive() override = default; ~DoSignaturePrimitive() override = default;
@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
class UnpackGraphPrimitive : public Primitive { class UnpackGraphPrimitive : public Primitive {
public: public:
explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)
: Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
~UnpackGraphPrimitive() override = default; ~UnpackGraphPrimitive() override = default;
MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)

View File

@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction()
{"scalar_sub", kPrimTypeTwoArgs}, {"scalar_sub", kPrimTypeTwoArgs},
{"scalar_floordiv", kPrimTypeTwoArgs}}) {} {"scalar_floordiv", kPrimTypeTwoArgs}}) {}
bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
bool result = false; bool result = false;
if (func != nullptr) { if (func != nullptr) {
@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu
return result; return result;
} }
int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const { int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
int prim_type = static_cast<int>(kPrimTypeUnknown); int prim_type = static_cast<int>(kPrimTypeUnknown);

View File

@ -41,21 +41,21 @@ class PrimToFunction;
class PrimToFunction { class PrimToFunction {
public: public:
// Return a thread-safe singleton instance // Return a thread-safe singleton instance
static PrimToFunction& GetInstance() { static PrimToFunction &GetInstance() {
static PrimToFunction instance; static PrimToFunction instance;
return instance; return instance;
} }
PrimToFunction(const PrimToFunction&) = delete; PrimToFunction(const PrimToFunction &) = delete;
PrimToFunction& operator=(const PrimToFunction&) = delete; PrimToFunction &operator=(const PrimToFunction &) = delete;
~PrimToFunction() = default; ~PrimToFunction() = default;
// Get the args and return value for a primitive instance. // Get the args and return value for a primitive instance.
bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const; bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const;
private: private:
PrimToFunction(); PrimToFunction();
// Get the number of primitive arguments // Get the number of primitive arguments
int GetPrimType(const PrimitivePtr& prim) const; int GetPrimType(const PrimitivePtr &prim) const;
const std::unordered_map<std::string, int> prim_func_type_map_; const std::unordered_map<std::string, int> prim_func_type_map_;
}; };
} // namespace prim } // namespace prim

View File

@ -24,7 +24,7 @@
namespace mindspore { namespace mindspore {
namespace ad { namespace ad {
Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller) Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller)
: primal_(primal), caller_(caller), dout_(nullptr) { : primal_(primal), caller_(caller), dout_(nullptr) {
if (k != nullptr) { if (k != nullptr) {
k_ = k; k_ = k;
@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP
AnfNodePtr Adjoint::k() { return k_; } AnfNodePtr Adjoint::k() { return k_; }
void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); }
void Adjoint::UpdateK(const AnfNodePtr& new_k) { void Adjoint::UpdateK(const AnfNodePtr &new_k) {
MS_EXCEPTION_IF_NULL(new_k); MS_EXCEPTION_IF_NULL(new_k);
MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString();
// In recursive case, it needs update. // In recursive case, it needs update.
for (auto& user : k_user_) { for (auto &user : k_user_) {
MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k"
<< new_k->ToString(); << new_k->ToString();
if (user.first->input(user.second) != k_) { if (user.first->input(user.second) != k_) {
@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; }
AnfNodePtr Adjoint::dout() { return dout_hole_; } AnfNodePtr Adjoint::dout() { return dout_hole_; }
void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) { void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) {
dout_user_.emplace_back(std::make_pair(user, index)); dout_user_.emplace_back(std::make_pair(user, index));
} }
void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) {
if (dout_ != nullptr) { if (dout_ != nullptr) {
MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString();
auto add = prim::GetPythonOps("hyper_add"); auto add = prim::GetPythonOps("hyper_add");
@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) {
void Adjoint::CallDoutHole() { void Adjoint::CallDoutHole() {
if (dout_ != nullptr) { if (dout_ != nullptr) {
for (auto& user : dout_user_) { for (auto &user : dout_user_) {
MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout "
<< dout_->ToString(); << dout_->ToString();
if (user.first->input(user.second) != dout_hole_) { if (user.first->input(user.second) != dout_hole_) {

View File

@ -28,15 +28,15 @@ namespace mindspore {
namespace ad { namespace ad {
class Adjoint { class Adjoint {
public: public:
Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller); Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller);
~Adjoint() = default; ~Adjoint() = default;
AnfNodePtr primal(); AnfNodePtr primal();
AnfNodePtr k(); AnfNodePtr k();
void UpdateK(const AnfNodePtr& k); void UpdateK(const AnfNodePtr &k);
void RegisterKUser(const CNodePtr& user, size_t index); void RegisterKUser(const CNodePtr &user, size_t index);
AnfNodePtr dout(); AnfNodePtr dout();
void AccumulateDout(const AnfNodePtr& dout_factor); void AccumulateDout(const AnfNodePtr &dout_factor);
void RegisterDoutUser(const CNodePtr& user, size_t index); void RegisterDoutUser(const CNodePtr &user, size_t index);
void CallDoutHole(); void CallDoutHole();
private: private:

View File

@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuple;
static AbstractBasePtr Reabs(const AbstractBasePtr& t) { static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
if (t == nullptr) { if (t == nullptr) {
return nullptr; return nullptr;
} }
@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto attributes = abs_class->attributes(); auto attributes = abs_class->attributes();
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
[](const AbstractAttribute& item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist); res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractDictionary>()) { } else if (t->isa<AbstractDictionary>()) {
auto abs_dict = dyn_cast<AbstractDictionary>(t); auto abs_dict = dyn_cast<AbstractDictionary>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto elements = abs_dict->elements(); auto elements = abs_dict->elements();
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
[](const AbstractAttribute& item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist); res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractList>()) { } else if (t->isa<AbstractList>()) {
auto abs_dict = dyn_cast<AbstractList>(t); auto abs_dict = dyn_cast<AbstractList>(t);
@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
return res; return res;
} }
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [getattr, data, attribute] // Inputs should be [getattr, data, attribute]
MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
auto ct = dyn_cast<AbstractClass>(dt); auto ct = dyn_cast<AbstractClass>(dt);
const auto& cmap = ct->attributes(); const auto &cmap = ct->attributes();
int count = 0; int count = 0;
for (auto& item : cmap) { for (auto &item : cmap) {
if (cons_is_str && item.first == cons_str) { if (cons_is_str && item.first == cons_str) {
break; break;
} }
@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
} }
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
// Inputs should be [dict_getitem, dict, item] // Inputs should be [dict_getitem, dict, item]
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
AnfNodePtr data = inputs[1]; AnfNodePtr data = inputs[1];
@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
auto ct = dyn_cast<abstract::AbstractDictionary>(dt); auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
const auto& cmap = ct->elements(); const auto &cmap = ct->elements();
int count = 0; int count = 0;
for (auto& item : cmap) { for (auto &item : cmap) {
if (cons_is_str && item.first == cons_str) { if (cons_is_str && item.first == cons_str) {
break; break;
} }
@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
} }
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) {
return node->func_graph()->NewCNode(inputs); return node->func_graph()->NewCNode(inputs);
} }
AnfNodePtr ErasePartialNode(const CNodePtr& node) { AnfNodePtr ErasePartialNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) {
return nullptr; return nullptr;
} }
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) {
return node->func_graph()->NewCNode(inputs); return node->func_graph()->NewCNode(inputs);
} }
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [list_getitem, list, item] // Inputs should be [list_getitem, list, item]
if (inputs.size() < 3) { if (inputs.size() < 3) {
MS_LOG(EXCEPTION) << "Node's input number < 3."; MS_LOG(EXCEPTION) << "Node's input number < 3.";
@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
} }
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [list_setitem, list, index, item] // Inputs should be [list_setitem, list, index, item]
if (inputs.size() < 4) { if (inputs.size() < 4) {
MS_LOG(EXCEPTION) << "Node's input number < 4."; MS_LOG(EXCEPTION) << "Node's input number < 4.";
@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
} }
AnfNodePtr EraseMakeDictNode(const CNodePtr& node) { AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
return inputs[2]; return inputs[2];
} }
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) { AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [make_keyword_arg, key, value] // Inputs should be [make_keyword_arg, key, value]
MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
return inputs[2]; return inputs[2];
} }
AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) { AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
const auto& inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [extract_keyword_arg, arg, key] // Inputs should be [extract_keyword_arg, arg, key]
MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
return inputs[2]; return inputs[2];
} }
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) { ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) {
const int DEPTH_MAX = 5; const int DEPTH_MAX = 5;
if (depth > DEPTH_MAX) { if (depth > DEPTH_MAX) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
} }
std::vector<ValuePtr> elements; std::vector<ValuePtr> elements;
for (const auto& it : value_list->value()) { for (const auto &it : value_list->value()) {
ValuePtr value = nullptr; ValuePtr value = nullptr;
if (it->isa<ValueList>()) { if (it->isa<ValueList>()) {
value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1); value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d
return std::make_shared<ValueTuple>(elements); return std::make_shared<ValueTuple>(elements);
} }
AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
ValuePtr value = node->value(); ValuePtr value = node->value();
auto value_list = value->cast<ValueListPtr>(); auto value_list = value->cast<ValueListPtr>();
@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) {
// Convert class to Tuple // Convert class to Tuple
// Convert getattr to getitem // Convert getattr to getitem
// Convert make_record to make_tuple // Convert make_record to make_tuple
void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root); manager->AddFuncGraph(root);
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes(); AnfNodeSet all_node = manager->all_nodes();
for (auto& node : all_node) { for (auto &node : all_node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
AnfNodePtr new_node = nullptr; AnfNodePtr new_node = nullptr;
@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr&
} }
} }
for (auto& node : manager->all_nodes()) { for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract()); auto ret = Reabs(node->abstract());
node->set_abstract(ret); node->set_abstract(ret);
} }
} }
// expand tuples in graph parameters // expand tuples in graph parameters
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph, static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr>& params) { const std::vector<AnfNodePtr> &params) {
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_params; std::vector<AnfNodePtr> new_params;
for (const auto& param : params) { for (const auto &param : params) {
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
auto param_abs = param->abstract(); auto param_abs = param->abstract();
MS_EXCEPTION_IF_NULL(param_abs); MS_EXCEPTION_IF_NULL(param_abs);
@ -350,7 +350,7 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con
std::vector<AnfNodePtr> new_param; std::vector<AnfNodePtr> new_param;
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)}; std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
auto abs_tuple = dyn_cast<AbstractTuple>(param_abs); auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
for (auto& elem : abs_tuple->elements()) { for (auto &elem : abs_tuple->elements()) {
auto np = std::make_shared<Parameter>(func_graph); auto np = std::make_shared<Parameter>(func_graph);
np->set_abstract(elem); np->set_abstract(elem);
new_param.emplace_back(np); new_param.emplace_back(np);
@ -366,11 +366,11 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con
} }
// expand tuples in graph applies // expand tuples in graph applies
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const std::vector<AnfNodePtr>& inputs) { static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> new_inputs; std::vector<AnfNodePtr> new_inputs;
for (const auto& input : inputs) { for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
auto input_abs = input->abstract(); auto input_abs = input->abstract();
@ -391,7 +391,7 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st
int idx = 0; int idx = 0;
std::vector<AnfNodePtr> new_input; std::vector<AnfNodePtr> new_input;
auto abs_tuple = dyn_cast<AbstractTuple>(input_abs); auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
for (auto& elem : abs_tuple->elements()) { for (auto &elem : abs_tuple->elements()) {
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx)); AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
c_node->input(2)->set_abstract(aptr); c_node->input(2)->set_abstract(aptr);
@ -416,19 +416,19 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st
// tuples in Graph's parameters: AbstractTuple (a, b, c) --> // tuples in Graph's parameters: AbstractTuple (a, b, c) -->
// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
// cppcheck-suppress unusedFunction // cppcheck-suppress unusedFunction
void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root); manager->AddFuncGraph(root);
// NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes(); AnfNodeSet all_node = manager->all_nodes();
for (auto& node : all_node) { for (auto &node : all_node) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) { if (cnode == nullptr) {
continue; continue;
} }
const auto& inputs = cnode->inputs(); const auto &inputs = cnode->inputs();
// Bypass the first input in inputs as it's fn. // Bypass the first input in inputs as it's fn.
if (!IsValueNode<Primitive>(inputs[0])) { if (!IsValueNode<Primitive>(inputs[0])) {
@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
} }
FuncGraphSet all_graph = manager->func_graphs(); FuncGraphSet all_graph = manager->func_graphs();
for (auto& func_graph : all_graph) { for (auto &func_graph : all_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
manager->SetParameters(func_graph, expand_p); manager->SetParameters(func_graph, expand_p);

View File

@ -22,7 +22,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
// Automatically adding control depend based on effect order and side effect analysis. // Automatically adding control depend based on effect order and side effect analysis.
void AddControlDepend(const FuncGraphPtr& graph); void AddControlDepend(const FuncGraphPtr &graph);
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_

View File

@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func
nodes.push_back(func_node); nodes.push_back(func_node);
// {unpackcall, {GradOperation, ...}, args...} // {unpackcall, {GradOperation, ...}, args...}
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr& node) { return node; }); [](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNode(nodes); unpack_graph_node = func_graph->NewCNode(nodes);
} else { } else {
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false); auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func
nodes.push_back(func_node); nodes.push_back(func_node);
// {{GradOperation, ...}, args...} // {{GradOperation, ...}, args...}
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr& node) { return node; }); [](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNode(nodes); unpack_graph_node = func_graph->NewCNode(nodes);
} }
return unpack_graph_node; return unpack_graph_node;
} }
// get metagraph of value node // get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
ValuePtr value; ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) { if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function(); value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
} }
// check if node is a specific metafuncgraph op // check if node is a specific metafuncgraph op
bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) {
if (node != nullptr) { if (node != nullptr) {
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
if (meta_func_graph_ptr == nullptr) { if (meta_func_graph_ptr == nullptr) {
@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr
// {{GradOperation, g, w}, Ys} // {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys} // {UnPackCall, {GradOperation, g, w}, Ys}
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (!node->isa<CNode>() || node->func_graph() == nullptr) { if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr; return nullptr;
} }

View File

@ -31,20 +31,20 @@
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction& renorm_action) { const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) { const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr& node) -> bool { auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return false; return false;
} }
for (auto& prim : prims) { for (auto &prim : prims) {
if (IsPrimitiveCNode(node, prim)) { if (IsPrimitiveCNode(node, prim)) {
return true; return true;
} }
@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const PredicateFuncType& predicate, const RenormAction& renorm_action) { const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action); return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
} }
AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
double t = GetTime(); double t = GetTime();
#endif #endif
@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode
return result; return result;
} }
bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node, bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
const SubstitutionPtr& transform) const { const SubstitutionPtr &transform) const {
FuncGraphManagerPtr manager = optimizer->manager(); FuncGraphManagerPtr manager = optimizer->manager();
std::unordered_set<AnfNodePtr> seen_node; std::unordered_set<AnfNodePtr> seen_node;
std::deque<AnfNodePtr> todo{root_node}; std::deque<AnfNodePtr> todo{root_node};
@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo
} }
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto& inputs = node->cast<CNodePtr>()->inputs(); auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
} }
auto& node_users = manager->node_users(); auto &node_users = manager->node_users();
if (change && node_users.find(node) != node_users.end()) { if (change && node_users.find(node) != node_users.end()) {
for (auto& use : node_users[node]) { for (auto &use : node_users[node]) {
auto use_node = use.first; auto use_node = use.first;
todo.push_back(use_node); todo.push_back(use_node);
if (seen_node.find(use_node) != seen_node.end()) { if (seen_node.find(use_node) != seen_node.end()) {
@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo
return changes; return changes;
} }
bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const { bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = optimizer->manager(); FuncGraphManagerPtr manager = optimizer->manager();
@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize
do { do {
loop = false; loop = false;
for (auto const& transform : list_) { for (auto const &transform : list_) {
auto change = ApplyTransform(optimizer, func_graph->output(), transform); auto change = ApplyTransform(optimizer, func_graph->output(), transform);
changes = changes || change; changes = changes || change;
loop = loop || change; loop = loop || change;

View File

@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) { std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t recursive_times = 0) {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES; << MAX_RECURSIVE_CALL_TIMES;
@ -39,7 +39,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto node_set = manager->node_users()[para]; auto node_set = manager->node_users()[para];
std::unordered_set<CNodePtr> cnode_set; std::unordered_set<CNodePtr> cnode_set;
for (auto& node_pair : node_set) { for (auto &node_pair : node_set) {
auto cnode = node_pair.first->cast<CNodePtr>(); auto cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) { if (!IsValueNode<Primitive>(cnode->input(0))) {
@ -54,7 +54,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
(void)cnode_set.emplace(cnode); (void)cnode_set.emplace(cnode);
} else { } else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
for (auto& cnode_sub : cnode_set_sub) { for (auto &cnode_sub : cnode_set_sub) {
(void)cnode_set.emplace(cnode_sub); (void)cnode_set.emplace(cnode_sub);
} }
} }
@ -63,8 +63,8 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
} }
Status AllreduceFusion::AddNodeToGraph() { Status AllreduceFusion::AddNodeToGraph() {
const auto& parameters = root_graph_->parameters(); const auto &parameters = root_graph_->parameters();
for (auto& parameter : parameters) { for (auto &parameter : parameters) {
if (!ParameterRequireGrad(parameter)) { if (!ParameterRequireGrad(parameter)) {
continue; continue;
} }
@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() {
if (cnode_set.empty()) { if (cnode_set.empty()) {
continue; continue;
} }
for (auto& cnode : cnode_set) { for (auto &cnode : cnode_set) {
MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() {
return SUCCESS; return SUCCESS;
} }
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const { CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES; << MAX_RECURSIVE_CALL_TIMES;
@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi
return cnode_dist; return cnode_dist;
} else { } else {
auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
for (auto& ele : cnode_dist_next) { for (auto &ele : cnode_dist_next) {
cnode_dist[ele.first] = cost + ele.second; cnode_dist[ele.first] = cost + ele.second;
} }
} }
} else { } else {
auto cnode_dist_next = FindNextCNodes(cnode); auto cnode_dist_next = FindNextCNodes(cnode);
for (auto& ele : cnode_dist_next) { for (auto &ele : cnode_dist_next) {
cnode_dist[ele.first] = ele.second; cnode_dist[ele.first] = ele.second;
} }
} }
return cnode_dist; return cnode_dist;
} }
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const { CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES; << MAX_RECURSIVE_CALL_TIMES;
} }
const auto& from_inputs = from->inputs(); const auto &from_inputs = from->inputs();
std::unordered_map<CNodePtr, double> dist_map; std::unordered_map<CNodePtr, double> dist_map;
MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
for (auto& input_node : from_inputs) { for (auto &input_node : from_inputs) {
auto cnode_dist = FindCNode(input_node, recursive_times + 1); auto cnode_dist = FindCNode(input_node, recursive_times + 1);
for (auto& ele : cnode_dist) { for (auto &ele : cnode_dist) {
(void)dist_map.emplace(ele); (void)dist_map.emplace(ele);
} }
} }
@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu
Status AllreduceFusion::AddEdgeToGraph() { Status AllreduceFusion::AddEdgeToGraph() {
std::unordered_map<CNodePtr, int32_t> cnode_state_map; std::unordered_map<CNodePtr, int32_t> cnode_state_map;
const auto& cnodes = allreduce_graph_.cnode_set(); const auto &cnodes = allreduce_graph_.cnode_set();
for (auto& cnode : cnodes) { for (auto &cnode : cnodes) {
cnode_state_map[cnode] = 0; cnode_state_map[cnode] = 0;
} }
const auto& head_cnode = allreduce_graph_.head_cnode(); const auto &head_cnode = allreduce_graph_.head_cnode();
std::queue<CNodePtr> cnode_queue; std::queue<CNodePtr> cnode_queue;
cnode_queue.emplace(head_cnode); cnode_queue.emplace(head_cnode);
cnode_state_map[head_cnode] = 1; cnode_state_map[head_cnode] = 1;
@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() {
cnode_queue.pop(); cnode_queue.pop();
cnode_state_map[cur_cnode] = 2; cnode_state_map[cur_cnode] = 2;
auto next = FindNextCNodes(cur_cnode); auto next = FindNextCNodes(cur_cnode);
for (auto& ele : next) { for (auto &ele : next) {
auto& cnode = ele.first; auto &cnode = ele.first;
auto& dist = ele.second; auto &dist = ele.second;
if (cnode_state_map[cnode] == 0) { if (cnode_state_map[cnode] == 0) {
cnode_queue.emplace(cnode); cnode_queue.emplace(cnode);
cnode_state_map[cnode] = 1; cnode_state_map[cnode] = 1;
@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() {
return SUCCESS; return SUCCESS;
} }
std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) { std::vector<CNodePtr> FindMirror(const AnfNodePtr &para, uint32_t recursive_times = 0) {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES; << MAX_RECURSIVE_CALL_TIMES;
@ -184,7 +184,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[para]; AnfNodeIndexSet node_set = manager->node_users()[para];
std::vector<CNodePtr> cnode_list; std::vector<CNodePtr> cnode_list;
for (auto& node_pair : node_set) { for (auto &node_pair : node_set) {
auto cnode = node_pair.first->cast<CNodePtr>(); auto cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) { if (!IsValueNode<Primitive>(cnode->input(0))) {
@ -210,7 +210,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time
return cnode_list; return cnode_list;
} }
void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) { void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string &parameter_name) {
MS_EXCEPTION_IF_NULL(mirror_cnode); MS_EXCEPTION_IF_NULL(mirror_cnode);
MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0)); auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name))); (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
} }
Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { Status FindMirrorAndSetFusion(const AnfNodePtr &para, int32_t fusion) {
auto mirror_cnodes = FindMirror(para); auto mirror_cnodes = FindMirror(para);
if (mirror_cnodes.empty()) { if (mirror_cnodes.empty()) {
MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
return SUCCESS; return SUCCESS;
} }
if (mirror_cnodes.size() > 2) { if (mirror_cnodes.size() > 2) {
for (auto& mirror_cnode : mirror_cnodes) { for (auto &mirror_cnode : mirror_cnodes) {
MS_EXCEPTION_IF_NULL(mirror_cnode); MS_EXCEPTION_IF_NULL(mirror_cnode);
MS_LOG(INFO) << mirror_cnode->DebugString(); MS_LOG(INFO) << mirror_cnode->DebugString();
} }
@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) {
<< "Mirror CNode found."; << "Mirror CNode found.";
return FAILED; return FAILED;
} }
for (auto& mirror_cnode : mirror_cnodes) { for (auto &mirror_cnode : mirror_cnodes) {
auto parameter_name = ParameterName(para); auto parameter_name = ParameterName(para);
SetMirrorFusion(mirror_cnode, fusion, parameter_name); SetMirrorFusion(mirror_cnode, fusion, parameter_name);
} }
return SUCCESS; return SUCCESS;
} }
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusion) { Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> &paras, int32_t fusion) {
for (auto& param_node : paras) { for (auto &param_node : paras) {
if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
return FAILED; return FAILED;
@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusi
return SUCCESS; return SUCCESS;
} }
Status AllreduceFusion::SetFusion(const std::vector<double>& cost_map) { Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
if (cost_map.size() < 2) { if (cost_map.size() < 2) {
MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
return FAILED; return FAILED;
@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) {
return SetFusionByBackwardCompAndAllreduceTime(); return SetFusionByBackwardCompAndAllreduceTime();
} }
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
if (ret == nullptr) { if (ret == nullptr) {
MS_LOG(ERROR) << "ret is nullptr."; MS_LOG(ERROR) << "ret is nullptr.";
return FAILED; return FAILED;

View File

@ -50,15 +50,15 @@ class AllreduceFusion {
allreduce_bandwidth_(0), allreduce_bandwidth_(0),
computation_time_parameter_(0) {} computation_time_parameter_(0) {}
virtual ~AllreduceFusion() = default; virtual ~AllreduceFusion() = default;
Status ProcessAllreduceFusion(const CNodePtr& ret); Status ProcessAllreduceFusion(const CNodePtr &ret);
private: private:
Status AddNodeToGraph(); Status AddNodeToGraph();
CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const; CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const;
CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const; CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const;
Status AddEdgeToGraph(); Status AddEdgeToGraph();
std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const; std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const;
Status SetFusion(const std::vector<double>& cost_map); Status SetFusion(const std::vector<double> &cost_map);
Status SetFusionByAlgorithm(int32_t algorithm); Status SetFusionByAlgorithm(int32_t algorithm);
Status SetFusionByBackwardCompTime(); Status SetFusionByBackwardCompTime();
Status SetFusionByBackwardCompAndAllreduceTime(); Status SetFusionByBackwardCompAndAllreduceTime();

View File

@ -23,7 +23,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) {
AllreduceNodePtr arnode; AllreduceNodePtr arnode;
auto cnode_emplace_return = cnode_set_.emplace(node); auto cnode_emplace_return = cnode_set_.emplace(node);
if (!cnode_emplace_return.second) { if (!cnode_emplace_return.second) {
@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
return SUCCESS; return SUCCESS;
} }
Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) { Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
auto from_arnode_iter = cnode_arnode_map_.find(from); auto from_arnode_iter = cnode_arnode_map_.find(from);
if (from_arnode_iter == cnode_arnode_map_.end()) { if (from_arnode_iter == cnode_arnode_map_.end()) {
MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double
return SUCCESS; return SUCCESS;
} }
bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const { bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
auto cnode_iter = cnode_set_.find(node); auto cnode_iter = cnode_set_.find(node);
return !(cnode_iter == cnode_set_.end()); return !(cnode_iter == cnode_set_.end());
} }
std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) { std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
std::vector<AnfNodePtr> nodes; std::vector<AnfNodePtr> nodes;
for (auto& cnode_arnode : cnode_arnode_map_) { for (auto &cnode_arnode : cnode_arnode_map_) {
MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
<< ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
<< " curr_para_size: " << cnode_arnode.second->curr_para_size(); << " curr_para_size: " << cnode_arnode.second->curr_para_size();
@ -117,7 +117,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
std::vector<AnfNodePtr> nodes; std::vector<AnfNodePtr> nodes;
double cur_para_size = 0; double cur_para_size = 0;
double from = to; double from = to;
for (auto& arnode : arnode_vec_) { for (auto &arnode : arnode_vec_) {
if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
continue; continue;
} }
@ -135,14 +135,14 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
void AllreduceGraph::PrintCNodeSet() const { void AllreduceGraph::PrintCNodeSet() const {
MS_LOG(INFO) << "CNodeSet:"; MS_LOG(INFO) << "CNodeSet:";
for (auto& cnode : cnode_set_) { for (auto &cnode : cnode_set_) {
MS_LOG(INFO) << cnode->DebugString(); MS_LOG(INFO) << cnode->DebugString();
} }
} }
void AllreduceGraph::PrintAllredueGraphInfo() const { void AllreduceGraph::PrintAllredueGraphInfo() const {
MS_LOG(INFO) << "max: " << max_; MS_LOG(INFO) << "max: " << max_;
for (auto& cnode_arnode : cnode_arnode_map_) { for (auto &cnode_arnode : cnode_arnode_map_) {
MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
MS_LOG(INFO) << "arnode info: "; MS_LOG(INFO) << "arnode info: ";
cnode_arnode.second->ToString(); cnode_arnode.second->ToString();
@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const {
void AllreduceGraph::PrintArnodeVec() const { void AllreduceGraph::PrintArnodeVec() const {
MS_LOG(INFO) << "ArnodeVec:"; MS_LOG(INFO) << "ArnodeVec:";
for (auto& arnode : arnode_vec_) { for (auto &arnode : arnode_vec_) {
arnode.ToString(); arnode.ToString();
} }
} }
void AllreduceGraph::PrintArnodeSet() const { void AllreduceGraph::PrintArnodeSet() const {
MS_LOG(INFO) << "ArnodeSet:"; MS_LOG(INFO) << "ArnodeSet:";
for (auto& arnode : arnode_set_) { for (auto &arnode : arnode_set_) {
arnode->ToString(); arnode->ToString();
} }
} }
void AllreduceGraph::SortArnode() { void AllreduceGraph::SortArnode() {
arnode_vec_.clear(); arnode_vec_.clear();
for (auto& node : arnode_set_) { for (auto &node : arnode_set_) {
arnode_vec_.emplace_back(*node); arnode_vec_.emplace_back(*node);
} }
std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() {
Status AllreduceGraph::RemoveExtraParas() { Status AllreduceGraph::RemoveExtraParas() {
std::unordered_set<AnfNodePtr> para_map; std::unordered_set<AnfNodePtr> para_map;
for (auto& node : arnode_vec_) { for (auto &node : arnode_vec_) {
for (auto& para : node.paras()) { for (auto &para : node.paras()) {
auto emplac_result = para_map.emplace(para); auto emplac_result = para_map.emplace(para);
if (!emplac_result.second) { if (!emplac_result.second) {
MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() {
return SUCCESS; return SUCCESS;
} }
Status AllreduceGraph::set_head_cnode(const CNodePtr& node) { Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
auto arnode = std::make_shared<AllreduceNode>(AllreduceNode()); auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
if (arnode->Init(node) != SUCCESS) { if (arnode->Init(node) != SUCCESS) {
MS_LOG(ERROR) << "AllreduceNode Init failed"; MS_LOG(ERROR) << "AllreduceNode Init failed";

View File

@ -42,9 +42,9 @@ class AllreduceGraph {
cnode_arnode_map_(), cnode_arnode_map_(),
max_(0) {} max_(0) {}
virtual ~AllreduceGraph() = default; virtual ~AllreduceGraph() = default;
Status AddNode(const CNodePtr& node, const AnfNodePtr& para); Status AddNode(const CNodePtr &node, const AnfNodePtr &para);
Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist); Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist);
bool NodeInGraph(const CNodePtr& node) const; bool NodeInGraph(const CNodePtr &node) const;
std::vector<AnfNodePtr> GetParaByCost(double from, double to); std::vector<AnfNodePtr> GetParaByCost(double from, double to);
// Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is
// over para_size. // over para_size.
@ -60,9 +60,9 @@ class AllreduceGraph {
void PrintAllredueGraphInfo() const; void PrintAllredueGraphInfo() const;
void PrintArnodeVec() const; void PrintArnodeVec() const;
void PrintArnodeSet() const; void PrintArnodeSet() const;
const std::unordered_set<CNodePtr>& cnode_set() const { return cnode_set_; } const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; }
CNodePtr head_cnode() const { return head_cnode_; } CNodePtr head_cnode() const { return head_cnode_; }
Status set_head_cnode(const CNodePtr& node); Status set_head_cnode(const CNodePtr &node);
double max() const { return max_; } double max() const { return max_; }
private: private:

Some files were not shown because too many files have changed in this diff Show More