forked from OSSInnovation/mindspore
update clang format rule
This commit is contained in:
parent
31a12009dd
commit
c2b3360d69
|
@ -94,7 +94,7 @@ PenaltyBreakString: 1000
|
||||||
PenaltyBreakTemplateDeclaration: 10
|
PenaltyBreakTemplateDeclaration: 10
|
||||||
PenaltyExcessCharacter: 1000000
|
PenaltyExcessCharacter: 1000000
|
||||||
PenaltyReturnTypeOnItsOwnLine: 200
|
PenaltyReturnTypeOnItsOwnLine: 200
|
||||||
PointerAlignment: Left
|
PointerAlignment: Right
|
||||||
RawStringFormats:
|
RawStringFormats:
|
||||||
- Language: Cpp
|
- Language: Cpp
|
||||||
Delimiters:
|
Delimiters:
|
||||||
|
|
|
@ -23,7 +23,7 @@ namespace common {
|
||||||
const int CACHED_STR_NUM = 1 << 8;
|
const int CACHED_STR_NUM = 1 << 8;
|
||||||
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
|
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
|
||||||
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
|
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
|
||||||
const char* SafeCStr(const std::string&& str) {
|
const char *SafeCStr(const std::string &&str) {
|
||||||
static std::atomic<uint32_t> index{0};
|
static std::atomic<uint32_t> index{0};
|
||||||
uint32_t cur_index = index++;
|
uint32_t cur_index = index++;
|
||||||
cur_index = cur_index & CACHED_STR_MASK;
|
cur_index = cur_index & CACHED_STR_MASK;
|
||||||
|
|
|
@ -21,16 +21,16 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#define DISABLE_COPY_AND_ASSIGN(ClassType) \
|
#define DISABLE_COPY_AND_ASSIGN(ClassType) \
|
||||||
ClassType(const ClassType&) = delete; \
|
ClassType(const ClassType &) = delete; \
|
||||||
ClassType& operator=(const ClassType&) = delete;
|
ClassType &operator=(const ClassType &) = delete;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace common {
|
namespace common {
|
||||||
inline const char* SafeCStr(const std::string& str) { return str.c_str(); }
|
inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
|
||||||
const char* SafeCStr(const std::string&& str);
|
const char *SafeCStr(const std::string &&str);
|
||||||
|
|
||||||
static inline std::string GetEnv(const std::string& envvar) {
|
static inline std::string GetEnv(const std::string &envvar) {
|
||||||
const char* value = ::getenv(envvar.c_str());
|
const char *value = ::getenv(envvar.c_str());
|
||||||
|
|
||||||
if (value == nullptr) {
|
if (value == nullptr) {
|
||||||
return std::string();
|
return std::string();
|
||||||
|
|
|
@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {
|
||||||
|
|
||||||
~DecodeOp() = default;
|
~DecodeOp() = default;
|
||||||
|
|
||||||
Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override;
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
void Print(std::ostream& out) const override { out << "DecodeOp"; }
|
void Print(std::ostream &out) const override { out << "DecodeOp"; }
|
||||||
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
|
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool is_rgb_format_ = true;
|
bool is_rgb_format_ = true;
|
||||||
|
|
|
@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int
|
||||||
rnd_.seed(seed_);
|
rnd_.seed(seed_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input,
|
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
||||||
std::vector<std::shared_ptr<Tensor>>* output) {
|
std::vector<std::shared_ptr<Tensor>> *output) {
|
||||||
IO_CHECK_VECTOR(input, output);
|
IO_CHECK_VECTOR(input, output);
|
||||||
if (input.size() != NumInput())
|
if (input.size() != NumInput())
|
||||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");
|
||||||
|
@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs,
|
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs,
|
||||||
std::vector<TensorShape>& outputs) {
|
std::vector<TensorShape> &outputs) {
|
||||||
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||||
outputs.clear();
|
outputs.clear();
|
||||||
TensorShape out = TensorShape{-1, -1};
|
TensorShape out = TensorShape{-1, -1};
|
||||||
|
@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp
|
||||||
if (!outputs.empty()) return Status::OK();
|
if (!outputs.empty()) return Status::OK();
|
||||||
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
|
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
|
||||||
}
|
}
|
||||||
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) {
|
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||||
outputs[0] = inputs[0];
|
outputs[0] = inputs[0];
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp {
|
||||||
|
|
||||||
~DistortBoundingBoxCropOp() override = default;
|
~DistortBoundingBoxCropOp() override = default;
|
||||||
|
|
||||||
void Print(std::ostream& out) const override {
|
void Print(std::ostream &out) const override {
|
||||||
out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
|
out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Compute(const std::vector<std::shared_ptr<Tensor>>& input,
|
Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
||||||
std::vector<std::shared_ptr<Tensor>>* output) override;
|
std::vector<std::shared_ptr<Tensor>> *output) override;
|
||||||
|
|
||||||
uint32_t NumInput() override { return 5; }
|
uint32_t NumInput() override { return 5; }
|
||||||
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
|
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t max_attempts_;
|
int32_t max_attempts_;
|
||||||
|
|
|
@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
|
||||||
rnd_.seed(GetSeed());
|
rnd_.seed(GetSeed());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) {
|
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
IO_CHECK(input, output);
|
IO_CHECK(input, output);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");
|
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std:
|
||||||
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
|
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
|
||||||
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
|
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
|
||||||
}
|
}
|
||||||
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) {
|
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||||
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||||
outputs.clear();
|
outputs.clear();
|
||||||
TensorShape out = TensorShape{target_height_, target_width_};
|
TensorShape out = TensorShape{target_height_, target_width_};
|
||||||
|
@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs
|
||||||
if (!outputs.empty()) return Status::OK();
|
if (!outputs.empty()) return Status::OK();
|
||||||
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
|
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
|
||||||
}
|
}
|
||||||
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) {
|
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
|
||||||
double scale, aspect;
|
double scale, aspect;
|
||||||
*crop_width = w_in;
|
*crop_width = w_in;
|
||||||
*crop_height = h_in;
|
*crop_height = h_in;
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
constexpr char PARALLEL_STRATEGY[] = "strategy";
|
constexpr char PARALLEL_STRATEGY[] = "strategy";
|
||||||
void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false);
|
void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
|
||||||
// get MindSpore Intermediate Representation Path
|
// get MindSpore Intermediate Representation Path
|
||||||
std::string GetMsIrPath(void) {
|
std::string GetMsIrPath(void) {
|
||||||
std::string path;
|
std::string path;
|
||||||
const char* path_ptr = getenv("MS_IR_PATH");
|
const char *path_ptr = getenv("MS_IR_PATH");
|
||||||
if (path_ptr != nullptr) {
|
if (path_ptr != nullptr) {
|
||||||
path = path_ptr;
|
path = path_ptr;
|
||||||
char real_path[PATH_MAX] = {0};
|
char real_path[PATH_MAX] = {0};
|
||||||
|
@ -62,13 +62,13 @@ std::string GetMsIrPath(void) {
|
||||||
return path;
|
return path;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string dump_obj(const py::object& obj, const std::string& path) {
|
std::string dump_obj(const py::object &obj, const std::string &path) {
|
||||||
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||||
py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
|
py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
|
||||||
return py::str(name);
|
return py::str(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object load_obj(const std::string& path) {
|
py::object load_obj(const std::string &path) {
|
||||||
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||||
py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
|
py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
|
||||||
return obj;
|
return obj;
|
||||||
|
@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) {
|
||||||
|
|
||||||
// ============================================= MindSpore IR Exporter =============================================
|
// ============================================= MindSpore IR Exporter =============================================
|
||||||
|
|
||||||
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
|
std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
|
||||||
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
|
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
|
||||||
TypePtr type = dyn_cast<Type>(nd->Type());
|
TypePtr type = dyn_cast<Type>(nd->Type());
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const {
|
std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
|
||||||
std::string pkl_path = GetMsIrPath();
|
std::string pkl_path = GetMsIrPath();
|
||||||
// if not specified env 'MS_IR_PATH', do not create any files
|
// if not specified env 'MS_IR_PATH', do not create any files
|
||||||
if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
|
if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
|
||||||
|
@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca
|
||||||
return file_prefix + file_name;
|
return file_prefix + file_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) {
|
int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) {
|
||||||
if (func_graph == nullptr || param == nullptr) {
|
if (func_graph == nullptr || param == nullptr) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
|
||||||
|
|
||||||
// try to find index of parameter for SymbolicKeyInstance from all exported graphs
|
// try to find index of parameter for SymbolicKeyInstance from all exported graphs
|
||||||
// NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
|
// NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
|
||||||
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
|
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) {
|
||||||
if (param == nullptr) {
|
if (param == nullptr) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ret = -1;
|
int ret = -1;
|
||||||
for (const auto& item : exported) {
|
for (const auto &item : exported) {
|
||||||
auto pram_iter = item.second.find(param);
|
auto pram_iter = item.second.find(param);
|
||||||
if (pram_iter != item.second.end()) {
|
if (pram_iter != item.second.end()) {
|
||||||
return pram_iter->second;
|
return pram_iter->second;
|
||||||
|
@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) {
|
std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
return GetValueText(fg, node->value());
|
return GetValueText(fg, node->value());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) {
|
std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
|
||||||
auto py_funcs = mt_func_graph->GetPyFunctions();
|
auto py_funcs = mt_func_graph->GetPyFunctions();
|
||||||
if (py_funcs.empty()) {
|
if (py_funcs.empty()) {
|
||||||
return "";
|
return "";
|
||||||
|
@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
|
||||||
|
|
||||||
oss << "{";
|
oss << "{";
|
||||||
bool is_first = true;
|
bool is_first = true;
|
||||||
for (const auto& py_func : py_funcs) {
|
for (const auto &py_func : py_funcs) {
|
||||||
if (is_first) {
|
if (is_first) {
|
||||||
is_first = false;
|
is_first = false;
|
||||||
} else {
|
} else {
|
||||||
|
@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
|
||||||
* ├── GradOperation
|
* ├── GradOperation
|
||||||
* └── TupleAdd
|
* └── TupleAdd
|
||||||
*/
|
*/
|
||||||
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) {
|
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
|
||||||
if (meta_func_graph == nullptr) {
|
if (meta_func_graph == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
|
std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
|
@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
|
||||||
|
|
||||||
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||||
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
|
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
|
||||||
auto& func = do_signature->function();
|
auto &func = do_signature->function();
|
||||||
if (func->isa<Primitive>()) {
|
if (func->isa<Primitive>()) {
|
||||||
auto sig_prim = dyn_cast<Primitive>(func);
|
auto sig_prim = dyn_cast<Primitive>(func);
|
||||||
oss << sig_prim->GetAttrsText();
|
oss << sig_prim->GetAttrsText();
|
||||||
|
@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
|
std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
if (ns == nullptr) {
|
if (ns == nullptr) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
|
@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph,
|
std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
|
||||||
const SymbolicKeyInstancePtr& sym_inst) {
|
const SymbolicKeyInstancePtr &sym_inst) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(sym_inst);
|
MS_EXCEPTION_IF_NULL(sym_inst);
|
||||||
AnfNodePtr sym_node = sym_inst->node();
|
AnfNodePtr sym_node = sym_inst->node();
|
||||||
|
@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
|
std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
// output ValueList, ValueTuple
|
// output ValueList, ValueTuple
|
||||||
ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
|
ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
|
||||||
|
@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
|
std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
|
ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
|
||||||
oss << "{";
|
oss << "{";
|
||||||
bool first_flag = true;
|
bool first_flag = true;
|
||||||
for (const auto& elem : dict->value()) {
|
for (const auto &elem : dict->value()) {
|
||||||
if (first_flag) {
|
if (first_flag) {
|
||||||
first_flag = false;
|
first_flag = false;
|
||||||
} else {
|
} else {
|
||||||
|
@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
|
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
|
||||||
if (check_integrity_) {
|
if (check_integrity_) {
|
||||||
|
@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr&
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
|
std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
bool is_null_ptr = (func_graph == nullptr || value == nullptr);
|
bool is_null_ptr = (func_graph == nullptr || value == nullptr);
|
||||||
if (is_null_ptr) {
|
if (is_null_ptr) {
|
||||||
|
@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu
|
||||||
}
|
}
|
||||||
|
|
||||||
// this function is used to output node in CNode's inputs
|
// this function is used to output node in CNode's inputs
|
||||||
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
|
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const std::map<AnfNodePtr, int>& apply_map) {
|
const std::map<AnfNodePtr, int> &apply_map) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
if (func_graph == nullptr || node == nullptr) {
|
if (func_graph == nullptr || node == nullptr) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
|
@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
|
void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters,
|
||||||
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) {
|
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
|
||||||
bool first_flag = true;
|
bool first_flag = true;
|
||||||
for (const AnfNodePtr& param : parameters) {
|
for (const AnfNodePtr ¶m : parameters) {
|
||||||
if (first_flag) {
|
if (first_flag) {
|
||||||
first_flag = false;
|
first_flag = false;
|
||||||
ofs << " ";
|
ofs << " ";
|
||||||
|
@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) {
|
void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// output type of each input argument
|
// output type of each input argument
|
||||||
auto& inputs = node->inputs();
|
auto &inputs = node->inputs();
|
||||||
if (inputs.size() > 1) {
|
if (inputs.size() > 1) {
|
||||||
ofs << " #(";
|
ofs << " #(";
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod
|
||||||
ofs << " #scope: " << node->scope()->name();
|
ofs << " #scope: " << node->scope()->name();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes,
|
void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
|
||||||
const FuncGraphPtr& func_graph) {
|
const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int idx = 1;
|
int idx = 1;
|
||||||
std::map<AnfNodePtr, int> apply_map;
|
std::map<AnfNodePtr, int> apply_map;
|
||||||
for (const AnfNodePtr& node : nodes) {
|
for (const AnfNodePtr &node : nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
auto& inputs = cnode->inputs();
|
auto &inputs = cnode->inputs();
|
||||||
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
|
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
|
||||||
// non-return node
|
// non-return node
|
||||||
if (node != func_graph->get_return()) {
|
if (node != func_graph->get_return()) {
|
||||||
|
@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) {
|
void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun
|
||||||
ofs << "}\n";
|
ofs << "}\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
|
void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt
|
||||||
ofs.close();
|
ofs.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
|
void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
|
||||||
if (graphs.empty()) {
|
if (graphs.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
|
||||||
|
|
||||||
param_index = 1;
|
param_index = 1;
|
||||||
|
|
||||||
for (const auto& tagged_graph : graphs) {
|
for (const auto &tagged_graph : graphs) {
|
||||||
tagged_cnodes_ = tagged_graph.second;
|
tagged_cnodes_ = tagged_graph.second;
|
||||||
ExportOneFuncGraph(ofs, tagged_graph.first);
|
ExportOneFuncGraph(ofs, tagged_graph.first);
|
||||||
tagged_cnodes_.clear();
|
tagged_cnodes_.clear();
|
||||||
|
@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_DUMP_IR
|
#ifdef ENABLE_DUMP_IR
|
||||||
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) {
|
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap
|
||||||
ChangeFileMode(filename, S_IRUSR);
|
ChangeFileMode(filename, S_IRUSR);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
|
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
|
||||||
AnfExporter exporter("", false);
|
AnfExporter exporter("", false);
|
||||||
ChangeFileMode(filename, S_IRWXU);
|
ChangeFileMode(filename, S_IRWXU);
|
||||||
exporter.ExportFuncGraph(filename, graphs);
|
exporter.ExportFuncGraph(filename, graphs);
|
||||||
|
@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph
|
||||||
ChangeFileMode(filename, S_IRUSR);
|
ChangeFileMode(filename, S_IRUSR);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
|
void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
|
||||||
static bool already_printed = false;
|
static bool already_printed = false;
|
||||||
if (already_printed) {
|
if (already_printed) {
|
||||||
return;
|
return;
|
||||||
|
@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
|
||||||
<< "please recompile source to enable it. See help of building script.";
|
<< "please recompile source to enable it. See help of building script.";
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
|
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
|
||||||
static bool already_printed = false;
|
static bool already_printed = false;
|
||||||
if (already_printed) {
|
if (already_printed) {
|
||||||
return;
|
return;
|
||||||
|
@ -732,7 +732,7 @@ enum Token : int {
|
||||||
TOK_ERROR // file read error
|
TOK_ERROR // file read error
|
||||||
};
|
};
|
||||||
|
|
||||||
std::map<Token, const char*> token_text = {
|
std::map<Token, const char *> token_text = {
|
||||||
{TOK_INVALID, "invalid"}, // invalid token
|
{TOK_INVALID, "invalid"}, // invalid token
|
||||||
{TOK_LPARENTHESIS, "("}, // ( left parenthesis
|
{TOK_LPARENTHESIS, "("}, // ( left parenthesis
|
||||||
{TOK_RPARENTHESIS, ")"}, // ) right parenthesis
|
{TOK_RPARENTHESIS, ")"}, // ) right parenthesis
|
||||||
|
@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = {
|
||||||
class Lexer {
|
class Lexer {
|
||||||
public:
|
public:
|
||||||
// filename is checked in ImportIR;
|
// filename is checked in ImportIR;
|
||||||
explicit Lexer(const char* filename) : fin(filename) {}
|
explicit Lexer(const char *filename) : fin(filename) {}
|
||||||
|
|
||||||
~Lexer() {
|
~Lexer() {
|
||||||
try {
|
try {
|
||||||
if (fin.is_open()) {
|
if (fin.is_open()) {
|
||||||
fin.close();
|
fin.close();
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(ERROR) << "Exception when closing file";
|
MS_LOG(ERROR) << "Exception when closing file";
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
std::string exName(abi::__cxa_current_exception_type()->name());
|
std::string exName(abi::__cxa_current_exception_type()->name());
|
||||||
|
@ -776,7 +776,7 @@ class Lexer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsSingleCharToken(char ch, Token* token_ptr) {
|
bool IsSingleCharToken(char ch, Token *token_ptr) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
std::unordered_map<char, Token> char_to_token = {
|
std::unordered_map<char, Token> char_to_token = {
|
||||||
{'(', TOK_LPARENTHESIS},
|
{'(', TOK_LPARENTHESIS},
|
||||||
|
@ -806,7 +806,7 @@ class Lexer {
|
||||||
Token GetNextToken() {
|
Token GetNextToken() {
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
Token token = GetNextTokenInner();
|
Token token = GetNextTokenInner();
|
||||||
const char* str = token_text[token];
|
const char *str = token_text[token];
|
||||||
std::string text = (str == nullptr ? GetTokenText() : str);
|
std::string text = (str == nullptr ? GetTokenText() : str);
|
||||||
MS_LOG(DEBUG) << "------Parse token] " << text;
|
MS_LOG(DEBUG) << "------Parse token] " << text;
|
||||||
return token;
|
return token;
|
||||||
|
@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE;
|
||||||
|
|
||||||
class IrParser {
|
class IrParser {
|
||||||
public:
|
public:
|
||||||
explicit IrParser(const char* filename) : lexer_(filename) {}
|
explicit IrParser(const char *filename) : lexer_(filename) {}
|
||||||
|
|
||||||
~IrParser() {}
|
~IrParser() {}
|
||||||
|
|
||||||
py::object LoadObject(const std::string& file_name) const {
|
py::object LoadObject(const std::string &file_name) const {
|
||||||
std::string pkl_path = GetMsIrPath();
|
std::string pkl_path = GetMsIrPath();
|
||||||
py::object default_obj = load_obj(pkl_path + "/" + file_name);
|
py::object default_obj = load_obj(pkl_path + "/" + file_name);
|
||||||
return default_obj;
|
return default_obj;
|
||||||
|
@ -1087,7 +1087,7 @@ class IrParser {
|
||||||
MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
|
MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseParent(FuncGraphPtr* const parent_ptr) {
|
Token ParseParent(FuncGraphPtr *const parent_ptr) {
|
||||||
if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
|
if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1168,7 +1168,7 @@ class IrParser {
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) {
|
FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
|
||||||
Token tok = lexer_.SkipWhiteToken();
|
Token tok = lexer_.SkipWhiteToken();
|
||||||
while (tok == TOK_VARIABLE) {
|
while (tok == TOK_VARIABLE) {
|
||||||
if (ParseStatement(func_graph) == nullptr) {
|
if (ParseStatement(func_graph) == nullptr) {
|
||||||
|
@ -1264,56 +1264,56 @@ class IrParser {
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const {
|
void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = dtype;
|
*ptr = dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetTupleType(TypePtr* ptr) {
|
void SetTupleType(TypePtr *ptr) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<Tuple>();
|
*ptr = std::make_shared<Tuple>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetTupleType(TypePtr* ptr, const TypePtrList& elems) {
|
void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<Tuple>(elems);
|
*ptr = std::make_shared<Tuple>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) {
|
void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<TensorType>(elem_type);
|
*ptr = std::make_shared<TensorType>(elem_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetListType(TypePtr* ptr) {
|
void SetListType(TypePtr *ptr) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<List>();
|
*ptr = std::make_shared<List>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetListType(TypePtr* ptr, const TypePtrList& elems) {
|
void SetListType(TypePtr *ptr, const TypePtrList &elems) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<List>(elems);
|
*ptr = std::make_shared<List>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) {
|
void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<JTagged>(elem);
|
*ptr = std::make_shared<JTagged>(elem);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const {
|
void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1321,45 +1321,45 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
// void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
|
// void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
|
||||||
void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const {
|
void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<abstract::AbstractNone>();
|
*ptr = std::make_shared<abstract::AbstractNone>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {}
|
void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
|
||||||
void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {}
|
void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}
|
||||||
|
|
||||||
void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) {
|
void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// if one of elems is nullptr, just return
|
// if one of elems is nullptr, just return
|
||||||
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) {
|
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<abstract::AbstractTuple>(elems);
|
*ptr = std::make_shared<abstract::AbstractTuple>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) {
|
void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
|
*ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) {
|
void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) {
|
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*ptr = std::make_shared<abstract::AbstractList>(elems);
|
*ptr = std::make_shared<abstract::AbstractList>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) {
|
void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1367,7 +1367,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) {
|
Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
|
||||||
if (tok != TOK_LBRACKET) {
|
if (tok != TOK_LBRACKET) {
|
||||||
MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
|
MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
|
||||||
return tok;
|
return tok;
|
||||||
|
@ -1415,7 +1415,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) {
|
Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
|
||||||
if (tok != TOK_LPARENTHESIS) {
|
if (tok != TOK_LPARENTHESIS) {
|
||||||
if (ptr != nullptr) {
|
if (ptr != nullptr) {
|
||||||
SetBasicType(ptr, std::make_shared<TensorType>());
|
SetBasicType(ptr, std::make_shared<TensorType>());
|
||||||
|
@ -1454,7 +1454,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsNumberType(const std::string& type, TypeId* typeid_ptr) {
|
bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
static std::unordered_map<std::string, TypeId> basic_types = {
|
static std::unordered_map<std::string, TypeId> basic_types = {
|
||||||
{"Bool", kNumberTypeBool},
|
{"Bool", kNumberTypeBool},
|
||||||
|
@ -1486,7 +1486,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) {
|
void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
|
||||||
TypePtr dtype = nullptr;
|
TypePtr dtype = nullptr;
|
||||||
|
|
||||||
std::unordered_map<int, TypePtr> type_map = {
|
std::unordered_map<int, TypePtr> type_map = {
|
||||||
|
@ -1519,7 +1519,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) {
|
Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
|
||||||
if (type == "NoneType") {
|
if (type == "NoneType") {
|
||||||
SetBasicType(ptr, std::make_shared<TypeNone>());
|
SetBasicType(ptr, std::make_shared<TypeNone>());
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
|
@ -1541,7 +1541,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) {
|
Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
|
||||||
if (tok != TOK_IDENTIFIER) {
|
if (tok != TOK_IDENTIFIER) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1588,11 +1588,11 @@ class IrParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) {
|
Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
|
||||||
return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
|
return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) {
|
Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
|
||||||
Token tok = ParseAttribute(func_graph, prim);
|
Token tok = ParseAttribute(func_graph, prim);
|
||||||
while (tok == TOK_COMMA) {
|
while (tok == TOK_COMMA) {
|
||||||
tok = ParseAttribute(func_graph, prim);
|
tok = ParseAttribute(func_graph, prim);
|
||||||
|
@ -1603,7 +1603,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) {
|
Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
|
||||||
Token tok = lexer_.GetNextToken();
|
Token tok = lexer_.GetNextToken();
|
||||||
if (tok != TOK_IDENTIFIER) {
|
if (tok != TOK_IDENTIFIER) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
|
@ -1670,7 +1670,7 @@ class IrParser {
|
||||||
return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
|
return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) {
|
FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
|
||||||
Token tok = ParseArgument(func_graph, inputs_ptr);
|
Token tok = ParseArgument(func_graph, inputs_ptr);
|
||||||
while (tok == TOK_COMMA) {
|
while (tok == TOK_COMMA) {
|
||||||
tok = ParseArgument(func_graph, inputs_ptr);
|
tok = ParseArgument(func_graph, inputs_ptr);
|
||||||
|
@ -1681,9 +1681,9 @@ class IrParser {
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) {
|
AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) {
|
||||||
while (func_graph != nullptr) {
|
while (func_graph != nullptr) {
|
||||||
for (auto& ptr : func_graph->parameters()) {
|
for (auto &ptr : func_graph->parameters()) {
|
||||||
MS_EXCEPTION_IF_NULL(ptr);
|
MS_EXCEPTION_IF_NULL(ptr);
|
||||||
ParameterPtr param = ptr->cast<ParameterPtr>();
|
ParameterPtr param = ptr->cast<ParameterPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(param);
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
@ -1701,12 +1701,12 @@ class IrParser {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Match(const std::string& str, const std::string& pattern) const {
|
bool Match(const std::string &str, const std::string &pattern) const {
|
||||||
return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
|
return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename V>
|
template <typename T, typename V>
|
||||||
Token ParseScalar(ValuePtr* const val_ptr) {
|
Token ParseScalar(ValuePtr *const val_ptr) {
|
||||||
if (lexer_.GetNextToken() != TOK_NUMBER) {
|
if (lexer_.GetNextToken() != TOK_NUMBER) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1725,7 +1725,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename VT, typename V, typename T>
|
template <typename VT, typename V, typename T>
|
||||||
Token ParseScalar(ValuePtr* const val_ptr, Token tok) {
|
Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
|
||||||
if (tok != TOK_LPARENTHESIS) {
|
if (tok != TOK_LPARENTHESIS) {
|
||||||
*val_ptr = std::make_shared<T>();
|
*val_ptr = std::make_shared<T>();
|
||||||
return tok;
|
return tok;
|
||||||
|
@ -1735,7 +1735,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename VT, typename V, typename T, const unsigned nbits>
|
template <typename VT, typename V, typename T, const unsigned nbits>
|
||||||
Token ParseScalar(ValuePtr* const val_ptr, Token tok) {
|
Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
|
||||||
if (tok != TOK_LPARENTHESIS) {
|
if (tok != TOK_LPARENTHESIS) {
|
||||||
*val_ptr = std::make_shared<T>(nbits);
|
*val_ptr = std::make_shared<T>(nbits);
|
||||||
return tok;
|
return tok;
|
||||||
|
@ -1745,7 +1745,7 @@ class IrParser {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T StringToScalar(const std::string& text) {
|
T StringToScalar(const std::string &text) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
T value;
|
T value;
|
||||||
ss << text;
|
ss << text;
|
||||||
|
@ -1753,7 +1753,7 @@ class IrParser {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseTensor(ValuePtr* const val_ptr) {
|
Token ParseTensor(ValuePtr *const val_ptr) {
|
||||||
// parse type
|
// parse type
|
||||||
TypeId type;
|
TypeId type;
|
||||||
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
|
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
|
||||||
|
@ -1803,7 +1803,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParsePrimType(Token tok, PrimType* prim_type_ptr) {
|
Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
|
||||||
if (tok != TOK_LBRACE) {
|
if (tok != TOK_LBRACE) {
|
||||||
return tok;
|
return tok;
|
||||||
}
|
}
|
||||||
|
@ -1830,7 +1830,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) {
|
Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
|
||||||
if (tok != TOK_LPARENTHESIS) {
|
if (tok != TOK_LPARENTHESIS) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1855,7 +1855,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) {
|
Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
|
||||||
if (tok != TOK_LBRACE) {
|
if (tok != TOK_LBRACE) {
|
||||||
return tok;
|
return tok;
|
||||||
}
|
}
|
||||||
|
@ -1868,7 +1868,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseBoolValue(const std::string& key, bool* val_ptr) {
|
Token ParseBoolValue(const std::string &key, bool *val_ptr) {
|
||||||
if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
|
if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1892,7 +1892,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) {
|
Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
|
||||||
if (lexer_.GetNextToken() != TOK_LBRACE) {
|
if (lexer_.GetNextToken() != TOK_LBRACE) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1920,7 +1920,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) {
|
Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
|
||||||
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
|
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1951,7 +1951,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) {
|
Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
|
||||||
if (lexer_.GetNextToken() != TOK_AT_FILE) {
|
if (lexer_.GetNextToken() != TOK_AT_FILE) {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -1984,7 +1984,7 @@ class IrParser {
|
||||||
return next;
|
return next;
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) {
|
Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) {
|
||||||
if (Match(id, "MultitypeFuncGraph::")) {
|
if (Match(id, "MultitypeFuncGraph::")) {
|
||||||
std::string name = id.substr(strlen("MultitypeFuncGraph::"));
|
std::string name = id.substr(strlen("MultitypeFuncGraph::"));
|
||||||
auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
|
auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
|
||||||
|
@ -2024,8 +2024,8 @@ class IrParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr,
|
Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr,
|
||||||
AnfNodePtr* const node_ptr = nullptr) {
|
AnfNodePtr *const node_ptr = nullptr) {
|
||||||
if (id == "None") {
|
if (id == "None") {
|
||||||
*val_ptr = std::make_shared<None>();
|
*val_ptr = std::make_shared<None>();
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
|
@ -2075,9 +2075,9 @@ class IrParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid,
|
Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
|
||||||
const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes,
|
const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
|
||||||
ValuePtr* const val_ptr, AnfNodePtr* node_ptr) {
|
ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
|
||||||
if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
|
if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
|
||||||
if (node_is_valid && node_ptr != nullptr) {
|
if (node_is_valid && node_ptr != nullptr) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
@ -2097,8 +2097,8 @@ class IrParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr,
|
Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
|
||||||
AnfNodePtr* node_ptr = nullptr) {
|
AnfNodePtr *node_ptr = nullptr) {
|
||||||
Token left_tok = tok;
|
Token left_tok = tok;
|
||||||
|
|
||||||
std::vector<ValuePtr> elems;
|
std::vector<ValuePtr> elems;
|
||||||
|
@ -2138,7 +2138,7 @@ class IrParser {
|
||||||
return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
|
return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) {
|
Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
|
||||||
// tuple or list
|
// tuple or list
|
||||||
if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
|
if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
|
||||||
return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
|
return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
|
||||||
|
@ -2152,7 +2152,7 @@ class IrParser {
|
||||||
return TOK_ERROR;
|
return TOK_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr,
|
Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
|
||||||
Token tok = TOK_INVALID) {
|
Token tok = TOK_INVALID) {
|
||||||
if (tok == TOK_INVALID) {
|
if (tok == TOK_INVALID) {
|
||||||
tok = lexer_.GetNextToken();
|
tok = lexer_.GetNextToken();
|
||||||
|
@ -2193,7 +2193,7 @@ class IrParser {
|
||||||
return lexer_.GetNextToken();
|
return lexer_.GetNextToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) {
|
Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
|
||||||
Token tok = lexer_.GetNextToken();
|
Token tok = lexer_.GetNextToken();
|
||||||
if (tok == TOK_RPARENTHESIS) {
|
if (tok == TOK_RPARENTHESIS) {
|
||||||
return tok;
|
return tok;
|
||||||
|
@ -2208,7 +2208,7 @@ class IrParser {
|
||||||
return tok;
|
return tok;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; }
|
const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Lexer lexer_;
|
Lexer lexer_;
|
||||||
|
@ -2226,14 +2226,14 @@ class IrParser {
|
||||||
std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
|
std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<FuncGraphPtr> ImportIR(const std::string& filename) {
|
std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
|
||||||
IrParser parser(filename.c_str());
|
IrParser parser(filename.c_str());
|
||||||
parser.ParseFile();
|
parser.ParseFile();
|
||||||
return parser.GetFuncGraphs();
|
return parser.GetFuncGraphs();
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_DUMP_IR
|
#ifdef ENABLE_DUMP_IR
|
||||||
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
|
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Func graph is nullptr";
|
MS_LOG(ERROR) << "Func graph is nullptr";
|
||||||
return;
|
return;
|
||||||
|
@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
char real_path[PATH_MAX] = {0};
|
char real_path[PATH_MAX] = {0};
|
||||||
char* real_path_ret = nullptr;
|
char *real_path_ret = nullptr;
|
||||||
#if defined(_WIN32) || defined(_WIN64)
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
|
real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
|
||||||
#else
|
#else
|
||||||
|
@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
|
||||||
ChangeFileMode(file_path, S_IRUSR);
|
ChangeFileMode(file_path, S_IRUSR);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
void DumpIRProto(const FuncGraphPtr&, const std::string&) {
|
void DumpIRProto(const FuncGraphPtr &, const std::string &) {
|
||||||
static bool already_printed = false;
|
static bool already_printed = false;
|
||||||
if (already_printed) {
|
if (already_printed) {
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -39,7 +39,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
struct ParamPtrEqual {
|
struct ParamPtrEqual {
|
||||||
bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const {
|
bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const {
|
||||||
const ParameterPtr param1 = dyn_cast<Parameter>(t1);
|
const ParameterPtr param1 = dyn_cast<Parameter>(t1);
|
||||||
const ParameterPtr param2 = dyn_cast<Parameter>(t2);
|
const ParameterPtr param2 = dyn_cast<Parameter>(t2);
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ struct ParamPtrEqual {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ParamPtrHasher {
|
struct ParamPtrHasher {
|
||||||
std::size_t operator()(AnfNodePtr const& param) const {
|
std::size_t operator()(AnfNodePtr const ¶m) const {
|
||||||
const ParameterPtr parameter = dyn_cast<Parameter>(param);
|
const ParameterPtr parameter = dyn_cast<Parameter>(param);
|
||||||
if (parameter == nullptr) {
|
if (parameter == nullptr) {
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -64,39 +64,39 @@ struct ParamPtrHasher {
|
||||||
|
|
||||||
class AnfExporter {
|
class AnfExporter {
|
||||||
public:
|
public:
|
||||||
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
|
explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false)
|
||||||
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
|
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
|
||||||
func_graph_set.clear();
|
func_graph_set.clear();
|
||||||
exported.clear();
|
exported.clear();
|
||||||
}
|
}
|
||||||
virtual ~AnfExporter() {}
|
virtual ~AnfExporter() {}
|
||||||
|
|
||||||
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
|
void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
|
||||||
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
|
void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual std::string GetNodeType(const AnfNodePtr& nd);
|
virtual std::string GetNodeType(const AnfNodePtr &nd);
|
||||||
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
|
int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true);
|
||||||
int GetParamIndexFromExported(const AnfNodePtr& param);
|
int GetParamIndexFromExported(const AnfNodePtr ¶m);
|
||||||
std::string DumpObject(const py::object& obj, const std::string& category) const;
|
std::string DumpObject(const py::object &obj, const std::string &category) const;
|
||||||
std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node);
|
std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node);
|
||||||
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph);
|
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph);
|
||||||
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst);
|
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst);
|
||||||
std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value);
|
std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value);
|
||||||
std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
|
std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
|
||||||
std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
|
std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
|
||||||
std::string GetPrimitiveText(const PrimitivePtr& prim);
|
std::string GetPrimitiveText(const PrimitivePtr &prim);
|
||||||
std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value);
|
std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value);
|
||||||
std::string GetNameSpaceText(const parse::NameSpacePtr& ns);
|
std::string GetNameSpaceText(const parse::NameSpacePtr &ns);
|
||||||
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph);
|
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph);
|
||||||
std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
|
std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const std::map<AnfNodePtr, int>& apply_map);
|
const std::map<AnfNodePtr, int> &apply_map);
|
||||||
void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph);
|
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
|
||||||
void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
|
void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters,
|
||||||
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map);
|
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map);
|
||||||
|
|
||||||
void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node);
|
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
|
||||||
void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph);
|
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
int param_index;
|
int param_index;
|
||||||
OrderedSet<FuncGraphPtr> func_graph_set{};
|
OrderedSet<FuncGraphPtr> func_graph_set{};
|
||||||
|
@ -108,16 +108,16 @@ class AnfExporter {
|
||||||
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
|
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
|
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph);
|
||||||
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs);
|
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs);
|
||||||
|
|
||||||
std::vector<FuncGraphPtr> ImportIR(const std::string& filename);
|
std::vector<FuncGraphPtr> ImportIR(const std::string &filename);
|
||||||
|
|
||||||
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
|
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
|
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
|
||||||
|
|
||||||
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
|
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
|
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace draw {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Only for ValueNode
|
// Only for ValueNode
|
||||||
std::string ValueType(const ValueNodePtr& node) {
|
std::string ValueType(const ValueNodePtr &node) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) {
|
||||||
return v->type_name();
|
return v->type_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ReplaceSpecialChar(const std::string& str) {
|
std::string ReplaceSpecialChar(const std::string &str) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
for (size_t i = 0; i < str.size(); i++) {
|
for (size_t i = 0; i < str.size(); i++) {
|
||||||
if (str[i] == '<') {
|
if (str[i] == '<') {
|
||||||
|
@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// API of debug utils
|
// API of debug utils
|
||||||
void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs,
|
void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
|
||||||
bool is_user) {
|
bool is_user) {
|
||||||
if (sub_graphs == nullptr) {
|
if (sub_graphs == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto& nd : nodes) {
|
for (auto &nd : nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(nd);
|
MS_EXCEPTION_IF_NULL(nd);
|
||||||
auto sub_graph = nd->func_graph();
|
auto sub_graph = nd->func_graph();
|
||||||
if (sub_graph != nullptr) {
|
if (sub_graph != nullptr) {
|
||||||
|
@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
|
void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
|
||||||
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) {
|
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
|
||||||
if (sub_graphs == nullptr) {
|
if (sub_graphs == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int dup_idx = 0;
|
int dup_idx = 0;
|
||||||
|
|
||||||
for (auto& nd : nodes) {
|
for (auto &nd : nodes) {
|
||||||
for (auto& t : SuccIncoming(nd)) {
|
for (auto &t : SuccIncoming(nd)) {
|
||||||
MS_EXCEPTION_IF_NULL(t);
|
MS_EXCEPTION_IF_NULL(t);
|
||||||
MS_EXCEPTION_IF_NULL(nd);
|
MS_EXCEPTION_IF_NULL(nd);
|
||||||
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
|
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
|
||||||
|
@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) {
|
void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
|
||||||
if (digraph == nullptr) {
|
if (digraph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
|
||||||
}
|
}
|
||||||
|
|
||||||
// Draw edge
|
// Draw edge
|
||||||
for (auto& nd : nodes) {
|
for (auto &nd : nodes) {
|
||||||
auto succs = SuccIncoming(nd);
|
auto succs = SuccIncoming(nd);
|
||||||
auto num = succs.size();
|
auto num = succs.size();
|
||||||
for (size_t i = 0; i < num; i++) {
|
for (size_t i = 0; i < num; i++) {
|
||||||
auto& t = succs.at(i);
|
auto &t = succs.at(i);
|
||||||
MS_EXCEPTION_IF_NULL(t);
|
MS_EXCEPTION_IF_NULL(t);
|
||||||
if (t->isa<ValueNode>() || t->isa<Parameter>()) {
|
if (t->isa<ValueNode>() || t->isa<Parameter>()) {
|
||||||
if ((!is_user) || (i != 0)) {
|
if ((!is_user) || (i != 0)) {
|
||||||
|
@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) {
|
void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
|
||||||
DrawValueNodes(nodes, &sub_graphs);
|
DrawValueNodes(nodes, &sub_graphs);
|
||||||
|
|
||||||
// Draw subgraph
|
// Draw subgraph
|
||||||
for (const auto& gsub : sub_graphs) {
|
for (const auto &gsub : sub_graphs) {
|
||||||
digraph->SubGraph(gsub.first, gsub.second);
|
digraph->SubGraph(gsub.first, gsub.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_DUMP_IR
|
#ifdef ENABLE_DUMP_IR
|
||||||
void Draw(const std::string& filename, const FuncGraphPtr& func_graph) {
|
void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
|
||||||
const std::string dot_suffix = ".dot";
|
const std::string dot_suffix = ".dot";
|
||||||
std::string filename_with_suffix =
|
std::string filename_with_suffix =
|
||||||
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
|
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
|
||||||
DrawByOpt(filename_with_suffix, func_graph, false);
|
DrawByOpt(filename_with_suffix, func_graph, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
|
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
|
||||||
DrawByOpt(filename, func_graph, true);
|
DrawByOpt(filename, func_graph, true);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
void Draw(const std::string&, const FuncGraphPtr&) {
|
void Draw(const std::string &, const FuncGraphPtr &) {
|
||||||
static bool already_printed = false;
|
static bool already_printed = false;
|
||||||
if (already_printed) {
|
if (already_printed) {
|
||||||
return;
|
return;
|
||||||
|
@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) {
|
||||||
<< "please recompile source to enable it. See help of building script.";
|
<< "please recompile source to enable it. See help of building script.";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) {
|
void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
|
||||||
static bool already_printed = false;
|
static bool already_printed = false;
|
||||||
if (already_printed) {
|
if (already_printed) {
|
||||||
return;
|
return;
|
||||||
|
@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) {
|
||||||
return "plaintext";
|
return "plaintext";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Graphviz::Color(const AnfNodePtr& node) {
|
std::string Graphviz::Color(const AnfNodePtr &node) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -259,7 +259,7 @@ void BaseDigraph::Start() {
|
||||||
buffer_ << "compound=true" << std::endl;
|
buffer_ << "compound=true" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BaseDigraph::Head(const AnfNodePtr& node, int id) {
|
void BaseDigraph::Head(const AnfNodePtr &node, int id) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
|
void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
|
||||||
buffer_ << ":" << idx;
|
buffer_ << ":" << idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BaseDigraph::Tail(const FuncGraphPtr& func_graph) {
|
void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -304,12 +304,12 @@ void BaseDigraph::End() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
|
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
|
||||||
buffer_ << "parameters_" << key << "[shape=plaintext ";
|
buffer_ << "parameters_" << key << "[shape=plaintext ";
|
||||||
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
|
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
|
||||||
buffer_ << "<tr><td>parameters</td></tr>";
|
buffer_ << "<tr><td>parameters</td></tr>";
|
||||||
int count = 0;
|
int count = 0;
|
||||||
for (auto& parameter : key->parameters()) {
|
for (auto ¶meter : key->parameters()) {
|
||||||
buffer_ << "<tr><td>";
|
buffer_ << "<tr><td>";
|
||||||
buffer_ << parameter->ToString();
|
buffer_ << parameter->ToString();
|
||||||
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
|
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
|
||||||
|
@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
|
||||||
buffer_ << "</table>>,];";
|
buffer_ << "</table>>,];";
|
||||||
}
|
}
|
||||||
|
|
||||||
void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) {
|
void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
|
||||||
if (key == nullptr || gsub == nullptr) {
|
if (key == nullptr || gsub == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -361,12 +361,12 @@ Digraph::~Digraph() {
|
||||||
if (fout_.is_open()) {
|
if (fout_.is_open()) {
|
||||||
fout_.close();
|
fout_.close();
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(ERROR) << "Exception when closing file " << filename_;
|
MS_LOG(ERROR) << "Exception when closing file " << filename_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) {
|
static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
|
||||||
size_t start_pos = 0;
|
size_t start_pos = 0;
|
||||||
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
|
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
|
||||||
(void)str.replace(start_pos, from.length(), to);
|
(void)str.replace(start_pos, from.length(), to);
|
||||||
|
@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st
|
||||||
return str;
|
return str;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
|
static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(graph_obj);
|
MS_EXCEPTION_IF_NULL(graph_obj);
|
||||||
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
|
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
|
||||||
<< "'>";
|
<< "'>";
|
||||||
|
@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
|
||||||
graph_obj->buffer() << "</td></tr>";
|
graph_obj->buffer() << "</td></tr>";
|
||||||
graph_obj->buffer() << "<tr><td align='left'>";
|
graph_obj->buffer() << "<tr><td align='left'>";
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (const auto& attr : attrs) {
|
for (const auto &attr : attrs) {
|
||||||
if (i != 0) {
|
if (i != 0) {
|
||||||
graph_obj->buffer() << "<br/>";
|
graph_obj->buffer() << "<br/>";
|
||||||
}
|
}
|
||||||
|
@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
|
||||||
graph_obj->buffer() << "</table>>,";
|
graph_obj->buffer() << "</table>>,";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
|
static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
|
||||||
if (graph_obj == nullptr || node == nullptr) {
|
if (graph_obj == nullptr || node == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
|
static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
|
||||||
if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
|
if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
|
||||||
}
|
}
|
||||||
graph_obj->buffer() << ">";
|
graph_obj->buffer() << ">";
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (auto& attr : attrs) {
|
for (auto &attr : attrs) {
|
||||||
if (i != 0) {
|
if (i != 0) {
|
||||||
graph_obj->buffer() << "<br/>";
|
graph_obj->buffer() << "<br/>";
|
||||||
}
|
}
|
||||||
|
@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() {
|
||||||
if (fout_.is_open()) {
|
if (fout_.is_open()) {
|
||||||
fout_.close();
|
fout_.close();
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(ERROR) << "exception when closing file " << filename_;
|
MS_LOG(ERROR) << "exception when closing file " << filename_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,9 +31,9 @@ namespace parse = mindspore::parse;
|
||||||
|
|
||||||
class Graphviz {
|
class Graphviz {
|
||||||
public:
|
public:
|
||||||
Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {}
|
Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {}
|
||||||
|
|
||||||
explicit Graphviz(const std::string& name) : name_(name) {}
|
explicit Graphviz(const std::string &name) : name_(name) {}
|
||||||
|
|
||||||
virtual ~Graphviz() {}
|
virtual ~Graphviz() {}
|
||||||
|
|
||||||
|
@ -41,8 +41,8 @@ class Graphviz {
|
||||||
virtual void End() {}
|
virtual void End() {}
|
||||||
|
|
||||||
virtual std::string Shape(AnfNodePtr node);
|
virtual std::string Shape(AnfNodePtr node);
|
||||||
std::string Color(const AnfNodePtr& node);
|
std::string Color(const AnfNodePtr &node);
|
||||||
std::ostringstream& buffer() { return buffer_; }
|
std::ostringstream &buffer() { return buffer_; }
|
||||||
std::ostringstream buffer_;
|
std::ostringstream buffer_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -53,8 +53,8 @@ class Graphviz {
|
||||||
|
|
||||||
class BaseDigraph : public Graphviz {
|
class BaseDigraph : public Graphviz {
|
||||||
public:
|
public:
|
||||||
BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {}
|
BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {}
|
||||||
explicit BaseDigraph(const std::string& name) : Graphviz(name) {}
|
explicit BaseDigraph(const std::string &name) : Graphviz(name) {}
|
||||||
~BaseDigraph() override = default;
|
~BaseDigraph() override = default;
|
||||||
|
|
||||||
virtual void Node(AnfNodePtr node, int id = 0) = 0;
|
virtual void Node(AnfNodePtr node, int id = 0) = 0;
|
||||||
|
@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz {
|
||||||
void Start() override;
|
void Start() override;
|
||||||
void End() override;
|
void End() override;
|
||||||
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
|
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
|
||||||
void FuncGraphParameters(const FuncGraphPtr& key);
|
void FuncGraphParameters(const FuncGraphPtr &key);
|
||||||
void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub);
|
void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub);
|
||||||
|
|
||||||
const std::string& name() const { return name_; }
|
const std::string &name() const { return name_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void Head(const AnfNodePtr& node, int id = 0);
|
void Head(const AnfNodePtr &node, int id = 0);
|
||||||
void Tail(const AnfNodePtr& node, int idx, int id = 0);
|
void Tail(const AnfNodePtr &node, int idx, int id = 0);
|
||||||
void Tail(const FuncGraphPtr& func_graph);
|
void Tail(const FuncGraphPtr &func_graph);
|
||||||
};
|
};
|
||||||
|
|
||||||
class Digraph : public BaseDigraph {
|
class Digraph : public BaseDigraph {
|
||||||
public:
|
public:
|
||||||
Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
|
Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
|
||||||
explicit Digraph(const std::string& name) : BaseDigraph(name) {}
|
explicit Digraph(const std::string &name) : BaseDigraph(name) {}
|
||||||
~Digraph() override;
|
~Digraph() override;
|
||||||
|
|
||||||
void Node(AnfNodePtr node, int id = 0) override;
|
void Node(AnfNodePtr node, int id = 0) override;
|
||||||
|
@ -86,8 +86,8 @@ class Digraph : public BaseDigraph {
|
||||||
|
|
||||||
class ModelDigraph : public BaseDigraph {
|
class ModelDigraph : public BaseDigraph {
|
||||||
public:
|
public:
|
||||||
ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
|
ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
|
||||||
explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {}
|
explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {}
|
||||||
~ModelDigraph() override;
|
~ModelDigraph() override;
|
||||||
|
|
||||||
std::string Shape(AnfNodePtr node) override;
|
std::string Shape(AnfNodePtr node) override;
|
||||||
|
@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph {
|
||||||
};
|
};
|
||||||
|
|
||||||
// API to draw
|
// API to draw
|
||||||
void Draw(const std::string& filename, const FuncGraphPtr& func_graph);
|
void Draw(const std::string &filename, const FuncGraphPtr &func_graph);
|
||||||
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
|
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
} // namespace draw
|
} // namespace draw
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -33,38 +33,38 @@ class ProtoExporter {
|
||||||
ProtoExporter() {}
|
ProtoExporter() {}
|
||||||
~ProtoExporter() {}
|
~ProtoExporter() {}
|
||||||
|
|
||||||
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
|
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitModelInfo();
|
void InitModelInfo();
|
||||||
void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto);
|
void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto);
|
||||||
std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
|
std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const std::map<AnfNodePtr, size_t>& apply_map,
|
const std::map<AnfNodePtr, size_t> &apply_map,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr);
|
std::map<AnfNodePtr, size_t> *const_map_ptr);
|
||||||
void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto);
|
void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
|
||||||
void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto);
|
void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
|
||||||
void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto);
|
void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
|
||||||
void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto);
|
void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
|
||||||
void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto);
|
void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
|
||||||
void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto);
|
void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
|
||||||
|
|
||||||
void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto);
|
void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
|
||||||
void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto);
|
void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
|
||||||
void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto,
|
void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr);
|
std::map<AnfNodePtr, size_t> *const_map_ptr);
|
||||||
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr,
|
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto);
|
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
|
||||||
void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node,
|
void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
|
||||||
const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr,
|
const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
|
||||||
irpb::GraphProto* graph_proto);
|
irpb::GraphProto *graph_proto);
|
||||||
void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto);
|
void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);
|
||||||
|
|
||||||
static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
|
static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
|
||||||
|
|
||||||
irpb::ModelProto model_;
|
irpb::ModelProto model_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static irpb::DataType GetNumberDataType(const TypePtr& type) {
|
static irpb::DataType GetNumberDataType(const TypePtr &type) {
|
||||||
switch (type->type_id()) {
|
switch (type->type_id()) {
|
||||||
case kNumberTypeBool:
|
case kNumberTypeBool:
|
||||||
return irpb::DT_BOOL;
|
return irpb::DT_BOOL;
|
||||||
|
@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) {
|
void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
|
||||||
if (type_proto == nullptr) {
|
if (type_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
|
||||||
type_proto->set_data_type(irpb::DT_TENSOR);
|
type_proto->set_data_type(irpb::DT_TENSOR);
|
||||||
if (shape != nullptr && shape->isa<abstract::Shape>()) {
|
if (shape != nullptr && shape->isa<abstract::Shape>()) {
|
||||||
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
|
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
|
||||||
for (const auto& elem : shape_info->shape()) {
|
for (const auto &elem : shape_info->shape()) {
|
||||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
|
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (type->isa<Tuple>()) {
|
} else if (type->isa<Tuple>()) {
|
||||||
TuplePtr tuple_type = dyn_cast<Tuple>(type);
|
TuplePtr tuple_type = dyn_cast<Tuple>(type);
|
||||||
type_proto->set_data_type(irpb::DT_TUPLE);
|
type_proto->set_data_type(irpb::DT_TUPLE);
|
||||||
for (const auto& elem_type : tuple_type->elements()) {
|
for (const auto &elem_type : tuple_type->elements()) {
|
||||||
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
|
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
|
||||||
}
|
}
|
||||||
} else if (type->isa<TypeType>()) {
|
} else if (type->isa<TypeType>()) {
|
||||||
|
@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
|
||||||
} else if (type->isa<List>()) {
|
} else if (type->isa<List>()) {
|
||||||
ListPtr list_type = dyn_cast<List>(type);
|
ListPtr list_type = dyn_cast<List>(type);
|
||||||
type_proto->set_data_type(irpb::DT_LIST);
|
type_proto->set_data_type(irpb::DT_LIST);
|
||||||
for (const auto& elem_type : list_type->elements()) {
|
for (const auto &elem_type : list_type->elements()) {
|
||||||
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
|
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
|
||||||
}
|
}
|
||||||
} else if (type->isa<TypeAnything>()) {
|
} else if (type->isa<TypeAnything>()) {
|
||||||
|
@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) {
|
void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
|
||||||
if (node == nullptr || type_proto == nullptr) {
|
if (node == nullptr || type_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SetNodeOutputType(node->Type(), node->Shape(), type_proto);
|
SetNodeOutputType(node->Type(), node->Shape(), type_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) {
|
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
|
||||||
if (val == nullptr || value_proto == nullptr) {
|
if (val == nullptr || value_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (val->isa<StringImm>()) {
|
if (val->isa<StringImm>()) {
|
||||||
const StringImmPtr& value = dyn_cast<StringImm>(val);
|
const StringImmPtr &value = dyn_cast<StringImm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_STRING);
|
value_proto->set_dtype(irpb::DT_STRING);
|
||||||
value_proto->set_str_val(value->value());
|
value_proto->set_str_val(value->value());
|
||||||
} else if (val->isa<Scalar>()) {
|
} else if (val->isa<Scalar>()) {
|
||||||
|
@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
|
||||||
} else if (val->isa<tensor::Tensor>()) {
|
} else if (val->isa<tensor::Tensor>()) {
|
||||||
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
|
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
|
||||||
value_proto->set_dtype(irpb::DT_TENSOR);
|
value_proto->set_dtype(irpb::DT_TENSOR);
|
||||||
irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val();
|
irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
|
||||||
tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
|
tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
|
||||||
for (auto& elem : tensor_ptr->shape()) {
|
for (auto &elem : tensor_ptr->shape()) {
|
||||||
tensor_proto->add_dims(elem);
|
tensor_proto->add_dims(elem);
|
||||||
}
|
}
|
||||||
} else if (val->isa<TensorType>()) {
|
} else if (val->isa<TensorType>()) {
|
||||||
value_proto->set_dtype(irpb::DT_TYPE);
|
value_proto->set_dtype(irpb::DT_TYPE);
|
||||||
|
|
||||||
irpb::TypeProto* type_proto = value_proto->mutable_type_val();
|
irpb::TypeProto *type_proto = value_proto->mutable_type_val();
|
||||||
type_proto->set_data_type(irpb::DT_TENSOR);
|
type_proto->set_data_type(irpb::DT_TENSOR);
|
||||||
TypePtr elem_type = dyn_cast<TensorType>(val)->element();
|
TypePtr elem_type = dyn_cast<TensorType>(val)->element();
|
||||||
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
|
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
|
||||||
|
@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) {
|
void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) {
|
||||||
if (val == nullptr || value_proto == nullptr) {
|
if (val == nullptr || value_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (val->isa<BoolImm>()) {
|
if (val->isa<BoolImm>()) {
|
||||||
const BoolImmPtr& value = dyn_cast<BoolImm>(val);
|
const BoolImmPtr &value = dyn_cast<BoolImm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_BOOL);
|
value_proto->set_dtype(irpb::DT_BOOL);
|
||||||
value_proto->set_bool_val(value->value());
|
value_proto->set_bool_val(value->value());
|
||||||
} else if (val->isa<Int8Imm>()) {
|
} else if (val->isa<Int8Imm>()) {
|
||||||
const Int8ImmPtr& value = dyn_cast<Int8Imm>(val);
|
const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_INT8);
|
value_proto->set_dtype(irpb::DT_INT8);
|
||||||
value_proto->set_int_val(value->value());
|
value_proto->set_int_val(value->value());
|
||||||
} else if (val->isa<Int16Imm>()) {
|
} else if (val->isa<Int16Imm>()) {
|
||||||
const Int16ImmPtr& value = dyn_cast<Int16Imm>(val);
|
const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_INT16);
|
value_proto->set_dtype(irpb::DT_INT16);
|
||||||
value_proto->set_int_val(value->value());
|
value_proto->set_int_val(value->value());
|
||||||
} else if (val->isa<Int32Imm>()) {
|
} else if (val->isa<Int32Imm>()) {
|
||||||
const Int32ImmPtr& value = dyn_cast<Int32Imm>(val);
|
const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_INT32);
|
value_proto->set_dtype(irpb::DT_INT32);
|
||||||
value_proto->set_int_val(value->value());
|
value_proto->set_int_val(value->value());
|
||||||
} else if (val->isa<Int64Imm>()) {
|
} else if (val->isa<Int64Imm>()) {
|
||||||
const Int64ImmPtr& value = dyn_cast<Int64Imm>(val);
|
const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_INT64);
|
value_proto->set_dtype(irpb::DT_INT64);
|
||||||
value_proto->set_int_val(value->value());
|
value_proto->set_int_val(value->value());
|
||||||
} else if (val->isa<UInt8Imm>()) {
|
} else if (val->isa<UInt8Imm>()) {
|
||||||
const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val);
|
const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_UINT8);
|
value_proto->set_dtype(irpb::DT_UINT8);
|
||||||
value_proto->set_uint_val(value->value());
|
value_proto->set_uint_val(value->value());
|
||||||
} else if (val->isa<UInt16Imm>()) {
|
} else if (val->isa<UInt16Imm>()) {
|
||||||
const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val);
|
const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_UINT16);
|
value_proto->set_dtype(irpb::DT_UINT16);
|
||||||
value_proto->set_uint_val(value->value());
|
value_proto->set_uint_val(value->value());
|
||||||
} else if (val->isa<UInt32Imm>()) {
|
} else if (val->isa<UInt32Imm>()) {
|
||||||
const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val);
|
const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_UINT32);
|
value_proto->set_dtype(irpb::DT_UINT32);
|
||||||
value_proto->set_uint_val(value->value());
|
value_proto->set_uint_val(value->value());
|
||||||
} else if (val->isa<UInt64Imm>()) {
|
} else if (val->isa<UInt64Imm>()) {
|
||||||
const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val);
|
const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_UINT64);
|
value_proto->set_dtype(irpb::DT_UINT64);
|
||||||
value_proto->set_uint_val(value->value());
|
value_proto->set_uint_val(value->value());
|
||||||
} else if (val->isa<FP32Imm>()) {
|
} else if (val->isa<FP32Imm>()) {
|
||||||
const FP32ImmPtr& value = dyn_cast<FP32Imm>(val);
|
const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_FLOAT32);
|
value_proto->set_dtype(irpb::DT_FLOAT32);
|
||||||
value_proto->set_float_val(value->value());
|
value_proto->set_float_val(value->value());
|
||||||
} else if (val->isa<FP64Imm>()) {
|
} else if (val->isa<FP64Imm>()) {
|
||||||
const FP64ImmPtr& value = dyn_cast<FP64Imm>(val);
|
const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
|
||||||
value_proto->set_dtype(irpb::DT_FLOAT64);
|
value_proto->set_dtype(irpb::DT_FLOAT64);
|
||||||
value_proto->set_double_val(value->value());
|
value_proto->set_double_val(value->value());
|
||||||
} else {
|
} else {
|
||||||
|
@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) {
|
void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) {
|
||||||
if (val == nullptr || value_proto == nullptr) {
|
if (val == nullptr || value_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (val->isa<ValueTuple>()) {
|
if (val->isa<ValueTuple>()) {
|
||||||
const ValueTuplePtr& value = dyn_cast<ValueTuple>(val);
|
const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
|
||||||
value_proto->set_dtype(irpb::DT_TUPLE);
|
value_proto->set_dtype(irpb::DT_TUPLE);
|
||||||
for (const auto& item : value->value()) {
|
for (const auto &item : value->value()) {
|
||||||
SetValueToProto(item, value_proto->add_values());
|
SetValueToProto(item, value_proto->add_values());
|
||||||
}
|
}
|
||||||
} else if (val->isa<ValueList>()) {
|
} else if (val->isa<ValueList>()) {
|
||||||
const ValueListPtr& value = dyn_cast<ValueList>(val);
|
const ValueListPtr &value = dyn_cast<ValueList>(val);
|
||||||
value_proto->set_dtype(irpb::DT_LIST);
|
value_proto->set_dtype(irpb::DT_LIST);
|
||||||
for (const auto& item : value->value()) {
|
for (const auto &item : value->value()) {
|
||||||
SetValueToProto(item, value_proto->add_values());
|
SetValueToProto(item, value_proto->add_values());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) {
|
void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
|
||||||
if (val == nullptr || value_proto == nullptr) {
|
if (val == nullptr || value_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
value_proto->set_dtype(irpb::DT_DICT);
|
value_proto->set_dtype(irpb::DT_DICT);
|
||||||
for (const auto& item : val->value()) {
|
for (const auto &item : val->value()) {
|
||||||
irpb::NamedValueProto* named_val = value_proto->add_dict_val();
|
irpb::NamedValueProto *named_val = value_proto->add_dict_val();
|
||||||
named_val->set_key(item.first);
|
named_val->set_key(item.first);
|
||||||
SetValueToProto(item.second, named_val->mutable_value());
|
SetValueToProto(item.second, named_val->mutable_value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) {
|
void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) {
|
||||||
if (node == nullptr || node_proto == nullptr) {
|
if (node == nullptr || node_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr&
|
||||||
MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
|
MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node);
|
const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
|
||||||
node_proto->set_op_type(prim->name());
|
node_proto->set_op_type(prim->name());
|
||||||
for (const auto& attr : prim->attrs()) {
|
for (const auto &attr : prim->attrs()) {
|
||||||
irpb::AttributeProto* attr_proto = node_proto->add_attribute();
|
irpb::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_name(attr.first);
|
attr_proto->set_name(attr.first);
|
||||||
SetValueToProto(attr.second, attr_proto->mutable_value());
|
SetValueToProto(attr.second, attr_proto->mutable_value());
|
||||||
}
|
}
|
||||||
node_proto->set_scope(node->scope()->name());
|
node_proto->set_scope(node->scope()->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node,
|
std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
|
||||||
const std::map<AnfNodePtr, size_t>& apply_map,
|
const std::map<AnfNodePtr, size_t> &apply_map,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr) {
|
std::map<AnfNodePtr, size_t> *const_map_ptr) {
|
||||||
if (node == nullptr || const_map_ptr == nullptr) {
|
if (node == nullptr || const_map_ptr == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt
|
||||||
MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
|
MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) {
|
std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
InitModelInfo();
|
InitModelInfo();
|
||||||
irpb::GraphProto* graph_proto = model_.mutable_graph();
|
irpb::GraphProto *graph_proto = model_.mutable_graph();
|
||||||
ExportFuncGraph(func_graph, graph_proto);
|
ExportFuncGraph(func_graph, graph_proto);
|
||||||
return model_.SerializeAsString();
|
return model_.SerializeAsString();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) {
|
void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
|
||||||
if (func_graph == nullptr || graph_proto == nullptr) {
|
if (func_graph == nullptr || graph_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP
|
||||||
ExportValueNodes(const_map, graph_proto);
|
ExportValueNodes(const_map, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) {
|
void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
|
||||||
if (func_graph == nullptr || graph_proto == nullptr) {
|
if (func_graph == nullptr || graph_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> parameters = func_graph->parameters();
|
std::vector<AnfNodePtr> parameters = func_graph->parameters();
|
||||||
for (auto& param : parameters) {
|
for (auto ¶m : parameters) {
|
||||||
irpb::ParameterProto* param_proto = graph_proto->add_parameters();
|
irpb::ParameterProto *param_proto = graph_proto->add_parameters();
|
||||||
param_proto->set_name(param->ToString());
|
param_proto->set_name(param->ToString());
|
||||||
|
|
||||||
SetNodeOutputType(param, param_proto->mutable_type());
|
SetNodeOutputType(param, param_proto->mutable_type());
|
||||||
|
@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto,
|
void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr) {
|
std::map<AnfNodePtr, size_t> *const_map_ptr) {
|
||||||
if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
|
if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// topo sort nodes
|
// topo sort nodes
|
||||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||||
std::map<AnfNodePtr, size_t> apply_map;
|
std::map<AnfNodePtr, size_t> apply_map;
|
||||||
for (const AnfNodePtr& node : nodes) {
|
for (const AnfNodePtr &node : nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* apply_map_ptr,
|
std::map<AnfNodePtr, size_t> *apply_map_ptr,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) {
|
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
|
||||||
if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
|
if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
|
||||||
graph_proto == nullptr) {
|
graph_proto == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
|
||||||
auto apply_idx = apply_map_ptr->size() + 1;
|
auto apply_idx = apply_map_ptr->size() + 1;
|
||||||
(*apply_map_ptr)[node] = apply_idx;
|
(*apply_map_ptr)[node] = apply_idx;
|
||||||
|
|
||||||
auto& inputs = node->inputs();
|
auto &inputs = node->inputs();
|
||||||
if (inputs.size() < 1) {
|
if (inputs.size() < 1) {
|
||||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||||
}
|
}
|
||||||
AnfNodePtr op = inputs[0];
|
AnfNodePtr op = inputs[0];
|
||||||
irpb::NodeProto* node_proto = graph_proto->add_node();
|
irpb::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
|
||||||
// CNode/ConstGraph/Const/Parameter
|
// CNode/ConstGraph/Const/Parameter
|
||||||
if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
|
if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
|
||||||
|
@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
|
||||||
|
|
||||||
// process OP inputs
|
// process OP inputs
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
irpb::InputProto* input_proto = node_proto->add_input();
|
irpb::InputProto *input_proto = node_proto->add_input();
|
||||||
input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
|
input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
|
||||||
std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
|
std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
|
||||||
input_proto->set_name(id);
|
input_proto->set_name(id);
|
||||||
|
@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node,
|
void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
|
||||||
const std::map<AnfNodePtr, size_t>& apply_map,
|
const std::map<AnfNodePtr, size_t> &apply_map,
|
||||||
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) {
|
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
|
||||||
if (ret_node == nullptr || !ret_node->isa<CNode>()) {
|
if (ret_node == nullptr || !ret_node->isa<CNode>()) {
|
||||||
MS_LOG(EXCEPTION) << "Graph return node is illegal";
|
MS_LOG(EXCEPTION) << "Graph return node is illegal";
|
||||||
}
|
}
|
||||||
|
@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
|
||||||
if (graph_proto == nullptr) {
|
if (graph_proto == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
|
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
|
||||||
}
|
}
|
||||||
irpb::OutputProto* output_proto = graph_proto->add_outputs();
|
irpb::OutputProto *output_proto = graph_proto->add_outputs();
|
||||||
if (output_proto == nullptr) {
|
if (output_proto == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "output_proto is nullptr";
|
MS_LOG(EXCEPTION) << "output_proto is nullptr";
|
||||||
}
|
}
|
||||||
|
@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
|
||||||
SetNodeOutputType(arg, output_proto->mutable_type());
|
SetNodeOutputType(arg, output_proto->mutable_type());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) {
|
static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
|
||||||
return x.second < y.second;
|
return x.second < y.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) {
|
void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
|
||||||
std::vector<std::pair<AnfNodePtr, size_t>> nodes;
|
std::vector<std::pair<AnfNodePtr, size_t>> nodes;
|
||||||
(void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
|
(void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
|
||||||
[](const std::pair<AnfNodePtr, size_t>& item) { return item; });
|
[](const std::pair<AnfNodePtr, size_t> &item) { return item; });
|
||||||
|
|
||||||
sort(nodes.begin(), nodes.end(), CompareValue);
|
sort(nodes.begin(), nodes.end(), CompareValue);
|
||||||
|
|
||||||
for (auto& item : nodes) {
|
for (auto &item : nodes) {
|
||||||
if (graph_proto == nullptr) {
|
if (graph_proto == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
|
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
|
||||||
}
|
}
|
||||||
irpb::NamedValueProto* named_value = graph_proto->add_const_vals();
|
irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
|
||||||
MS_EXCEPTION_IF_NULL(named_value);
|
MS_EXCEPTION_IF_NULL(named_value);
|
||||||
named_value->set_key(GetConstNodeId(item.second));
|
named_value->set_key(GetConstNodeId(item.second));
|
||||||
SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
|
SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
|
||||||
|
@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m
|
||||||
|
|
||||||
void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); }
|
void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); }
|
||||||
|
|
||||||
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) {
|
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
|
||||||
ProtoExporter exporter;
|
ProtoExporter exporter;
|
||||||
return exporter.GetFuncGraphProtoString(func_graph);
|
return exporter.GetFuncGraphProtoString(func_graph);
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ Dump::Dump()
|
||||||
dump_iter_(0),
|
dump_iter_(0),
|
||||||
cur_iter_(0) {}
|
cur_iter_(0) {}
|
||||||
|
|
||||||
bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
|
bool Dump::IsKernelNeedDump(const std::string &kernel_name) {
|
||||||
if (dump_mode_ == 0) {
|
if (dump_mode_ == 0) {
|
||||||
// Dump All Kernels mode
|
// Dump All Kernels mode
|
||||||
return true;
|
return true;
|
||||||
|
@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
|
bool Dump::ParseDumpConfig(const std::string &dump_config_file) {
|
||||||
std::ifstream jsonFile(dump_config_file);
|
std::ifstream jsonFile(dump_config_file);
|
||||||
if (!jsonFile.is_open()) {
|
if (!jsonFile.is_open()) {
|
||||||
MS_LOG(ERROR) << dump_config_file << " open failed.";
|
MS_LOG(ERROR) << dump_config_file << " open failed.";
|
||||||
|
@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
|
bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) {
|
||||||
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
|
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
|
||||||
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
|
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
|
||||||
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
|
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
|
||||||
|
@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
|
bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
|
||||||
auto trans_flag = dumpSettings.at("trans_flag");
|
auto trans_flag = dumpSettings.at("trans_flag");
|
||||||
auto enable = dumpSettings.at("enable");
|
auto enable = dumpSettings.at("enable");
|
||||||
auto mode = dumpSettings.at("mode");
|
auto mode = dumpSettings.at("mode");
|
||||||
|
@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
|
||||||
dump_path_ = path;
|
dump_path_ = path;
|
||||||
dump_net_name_ = net_name;
|
dump_net_name_ = net_name;
|
||||||
dump_iter_ = iteration;
|
dump_iter_ = iteration;
|
||||||
for (const auto& kernel : kernels) {
|
for (const auto &kernel : kernels) {
|
||||||
dump_kernels_.push_back(kernel);
|
dump_kernels_.push_back(kernel);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::SetDumpConfFromJsonFile() {
|
bool Dump::SetDumpConfFromJsonFile() {
|
||||||
const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
|
const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
|
||||||
if (config_path_str != nullptr) {
|
if (config_path_str != nullptr) {
|
||||||
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
|
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
|
||||||
} else {
|
} else {
|
||||||
|
@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() {
|
||||||
return ParseDumpConfig(dump_config_file);
|
return ParseDumpConfig(dump_config_file);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) {
|
bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) {
|
||||||
if (filename.empty() || data == nullptr || len == 0) {
|
if (filename.empty() || data == nullptr || len == 0) {
|
||||||
MS_LOG(ERROR) << "Incorrect parameter.";
|
MS_LOG(ERROR) << "Incorrect parameter.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len)
|
||||||
MS_LOG(ERROR) << "Open file " << realpath << " fail.";
|
MS_LOG(ERROR) << "Open file " << realpath << " fail.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
(void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len));
|
(void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len));
|
||||||
fd.close();
|
fd.close();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
|
bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) {
|
||||||
MS_EXCEPTION_IF_NULL(outpath);
|
MS_EXCEPTION_IF_NULL(outpath);
|
||||||
auto path_split_pos = inpath.find_last_of('/');
|
auto path_split_pos = inpath.find_last_of('/');
|
||||||
if (path_split_pos == std::string::npos) {
|
if (path_split_pos == std::string::npos) {
|
||||||
|
@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dump::CreateNotExistDirs(const std::string& path) {
|
bool Dump::CreateNotExistDirs(const std::string &path) {
|
||||||
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
|
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
|
||||||
MS_EXCEPTION_IF_NULL(fs);
|
MS_EXCEPTION_IF_NULL(fs);
|
||||||
char temp_path[PATH_MAX] = {0};
|
char temp_path[PATH_MAX] = {0};
|
||||||
|
|
|
@ -43,11 +43,11 @@ class Dump {
|
||||||
|
|
||||||
uint32_t cur_iter() const { return cur_iter_; }
|
uint32_t cur_iter() const { return cur_iter_; }
|
||||||
|
|
||||||
bool IsKernelNeedDump(const std::string& kernel_name);
|
bool IsKernelNeedDump(const std::string &kernel_name);
|
||||||
|
|
||||||
bool SetDumpConfFromJsonFile();
|
bool SetDumpConfFromJsonFile();
|
||||||
|
|
||||||
static bool DumpToFile(const std::string& filename, const void* data, size_t len);
|
static bool DumpToFile(const std::string &filename, const void *data, size_t len);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool dump_enable_;
|
bool dump_enable_;
|
||||||
|
@ -59,14 +59,14 @@ class Dump {
|
||||||
uint32_t cur_iter_;
|
uint32_t cur_iter_;
|
||||||
std::vector<std::string> dump_kernels_;
|
std::vector<std::string> dump_kernels_;
|
||||||
|
|
||||||
static bool GetRealPath(const std::string& inpath, std::string* outpath);
|
static bool GetRealPath(const std::string &inpath, std::string *outpath);
|
||||||
|
|
||||||
static bool CreateNotExistDirs(const std::string& path);
|
static bool CreateNotExistDirs(const std::string &path);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool ParseDumpConfig(const std::string& dump_config_file);
|
bool ParseDumpConfig(const std::string &dump_config_file);
|
||||||
bool IsConfigExist(const nlohmann::json& dumpSettings);
|
bool IsConfigExist(const nlohmann::json &dumpSettings);
|
||||||
bool IsConfigValid(const nlohmann::json& dumpSettings);
|
bool IsConfigValid(const nlohmann::json &dumpSettings);
|
||||||
};
|
};
|
||||||
|
|
||||||
using DumpConfPtr = std::shared_ptr<Dump>;
|
using DumpConfPtr = std::shared_ptr<Dump>;
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include "pipeline/parse/python_adapter.h"
|
#include "pipeline/parse/python_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) {
|
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
|
||||||
std::string temp_line = line;
|
std::string temp_line = line;
|
||||||
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
|
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
|
||||||
tip != kSourceLineTipDiscard) {
|
tip != kSourceLineTipDiscard) {
|
||||||
|
@ -101,14 +101,14 @@ DebugInfo::DebugInfo() {
|
||||||
name_ = "";
|
name_ = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
DebugInfo::DebugInfo(const std::string& name) {
|
DebugInfo::DebugInfo(const std::string &name) {
|
||||||
InitValueFromContext();
|
InitValueFromContext();
|
||||||
unique_id_ = gen_unique_id();
|
unique_id_ = gen_unique_id();
|
||||||
debug_id_ = -1;
|
debug_id_ = -1;
|
||||||
name_ = name;
|
name_ = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
DebugInfo::DebugInfo(const LocationPtr& loc) {
|
DebugInfo::DebugInfo(const LocationPtr &loc) {
|
||||||
InitValueFromContext();
|
InitValueFromContext();
|
||||||
unique_id_ = gen_unique_id();
|
unique_id_ = gen_unique_id();
|
||||||
debug_id_ = -1;
|
debug_id_ = -1;
|
||||||
|
@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t DebugInfo::unique_id_through_copy() const {
|
int64_t DebugInfo::unique_id_through_copy() const {
|
||||||
TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info();
|
TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info();
|
||||||
if (trace_info != nullptr) {
|
if (trace_info != nullptr) {
|
||||||
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
|
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
|
||||||
return trace_info->debug_info()->unique_id_through_copy();
|
return trace_info->debug_info()->unique_id_through_copy();
|
||||||
|
@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() {
|
||||||
}
|
}
|
||||||
return DebugInfo::location();
|
return DebugInfo::location();
|
||||||
}
|
}
|
||||||
void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; }
|
void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; }
|
||||||
|
|
||||||
TraceContextPtr TraceManager::CurrentContextInfo() {
|
TraceContextPtr TraceManager::CurrentContextInfo() {
|
||||||
if (!TraceManager::trace_context_stack_.empty()) {
|
if (!TraceManager::trace_context_stack_.empty()) {
|
||||||
|
@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) {
|
void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) {
|
||||||
TraceContextPtr context = std::make_shared<TraceContext>(location);
|
TraceContextPtr context = std::make_shared<TraceContext>(location);
|
||||||
context->set_func_name(func_name);
|
context->set_func_name(func_name);
|
||||||
TraceManager::trace_context_stack_.push(context);
|
TraceManager::trace_context_stack_.push(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceManager::DebugTrace(const LocationPtr& location) {
|
void TraceManager::DebugTrace(const LocationPtr &location) {
|
||||||
TraceContextPtr context = std::make_shared<TraceContext>(location);
|
TraceContextPtr context = std::make_shared<TraceContext>(location);
|
||||||
TraceManager::trace_context_stack_.push(context);
|
TraceManager::trace_context_stack_.push(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
|
void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) {
|
||||||
if (trace_info == nullptr) {
|
if (trace_info == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
|
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
|
||||||
}
|
}
|
||||||
|
@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
|
||||||
TraceManager::trace_context_stack_.push(context);
|
TraceManager::trace_context_stack_.push(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) {
|
void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
|
||||||
if (trace_info == nullptr) {
|
if (trace_info == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
|
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou
|
||||||
// Location class record the location in source code.
|
// Location class record the location in source code.
|
||||||
class Location {
|
class Location {
|
||||||
public:
|
public:
|
||||||
Location(const std::string& file_name, int line, int column, int line_end, int column_end)
|
Location(const std::string &file_name, int line, int column, int line_end, int column_end)
|
||||||
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
|
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
|
||||||
Location(const Location& loc)
|
Location(const Location &loc)
|
||||||
: file_name_(loc.file_name_),
|
: file_name_(loc.file_name_),
|
||||||
line_(loc.line_),
|
line_(loc.line_),
|
||||||
column_(loc.column_),
|
column_(loc.column_),
|
||||||
|
@ -77,21 +77,21 @@ class TraceManager {
|
||||||
TraceManager() = default;
|
TraceManager() = default;
|
||||||
~TraceManager() = default;
|
~TraceManager() = default;
|
||||||
static TraceContextPtr CurrentContextInfo();
|
static TraceContextPtr CurrentContextInfo();
|
||||||
static void DebugTrace(const std::string& func_name, const LocationPtr& location);
|
static void DebugTrace(const std::string &func_name, const LocationPtr &location);
|
||||||
static void DebugTrace(const LocationPtr& location);
|
static void DebugTrace(const LocationPtr &location);
|
||||||
static void DebugTrace(const TraceInfoPtr& trace_info);
|
static void DebugTrace(const TraceInfoPtr &trace_info);
|
||||||
// debug trace with a cloned trace info with debug_info
|
// debug trace with a cloned trace info with debug_info
|
||||||
static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info);
|
static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
|
||||||
static void EndTrace();
|
static void EndTrace();
|
||||||
static std::stack<TraceContextPtr> trace_context_stack_;
|
static std::stack<TraceContextPtr> trace_context_stack_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TraceGuard {
|
class TraceGuard {
|
||||||
public:
|
public:
|
||||||
explicit TraceGuard(const std::string func_name, const LocationPtr& location) {
|
explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
|
||||||
TraceManager::DebugTrace(func_name, location);
|
TraceManager::DebugTrace(func_name, location);
|
||||||
}
|
}
|
||||||
explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); }
|
explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
|
||||||
~TraceGuard() { TraceManager::EndTrace(); }
|
~TraceGuard() { TraceManager::EndTrace(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -106,23 +106,23 @@ class TraceContext {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
~TraceContext() = default;
|
~TraceContext() = default;
|
||||||
explicit TraceContext(const LocationPtr& loc) {
|
explicit TraceContext(const LocationPtr &loc) {
|
||||||
ProcessAttributeFromContext();
|
ProcessAttributeFromContext();
|
||||||
location_ = loc;
|
location_ = loc;
|
||||||
}
|
}
|
||||||
explicit TraceContext(const std::string& func_name) {
|
explicit TraceContext(const std::string &func_name) {
|
||||||
ProcessAttributeFromContext();
|
ProcessAttributeFromContext();
|
||||||
func_name_ = func_name;
|
func_name_ = func_name;
|
||||||
}
|
}
|
||||||
explicit TraceContext(const TraceInfoPtr& trace_info) {
|
explicit TraceContext(const TraceInfoPtr &trace_info) {
|
||||||
ProcessAttributeFromContext();
|
ProcessAttributeFromContext();
|
||||||
trace_info_ = trace_info;
|
trace_info_ = trace_info;
|
||||||
}
|
}
|
||||||
void set_location(const LocationPtr& loc) { location_ = loc; }
|
void set_location(const LocationPtr &loc) { location_ = loc; }
|
||||||
LocationPtr location() { return location_; }
|
LocationPtr location() { return location_; }
|
||||||
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
|
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
|
||||||
TraceInfoPtr trace_info() { return trace_info_; }
|
TraceInfoPtr trace_info() { return trace_info_; }
|
||||||
void set_func_name(const std::string& func_name) { func_name_ = func_name; }
|
void set_func_name(const std::string &func_name) { func_name_ = func_name; }
|
||||||
std::string func_name() { return func_name_; }
|
std::string func_name() { return func_name_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -130,9 +130,9 @@ class DebugInfo : public Base {
|
||||||
public:
|
public:
|
||||||
DebugInfo();
|
DebugInfo();
|
||||||
|
|
||||||
explicit DebugInfo(const std::string& name);
|
explicit DebugInfo(const std::string &name);
|
||||||
|
|
||||||
explicit DebugInfo(const LocationPtr& loc);
|
explicit DebugInfo(const LocationPtr &loc);
|
||||||
|
|
||||||
virtual ~DebugInfo() = default;
|
virtual ~DebugInfo() = default;
|
||||||
MS_DECLARE_PARENT(DebugInfo, Base);
|
MS_DECLARE_PARENT(DebugInfo, Base);
|
||||||
|
@ -141,12 +141,12 @@ class DebugInfo : public Base {
|
||||||
int64_t unique_id_through_copy() const;
|
int64_t unique_id_through_copy() const;
|
||||||
std::string get_id() { return std::to_string(debug_id()); }
|
std::string get_id() { return std::to_string(debug_id()); }
|
||||||
|
|
||||||
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
|
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
|
||||||
TraceInfoPtr trace_info() { return trace_info_; }
|
TraceInfoPtr trace_info() { return trace_info_; }
|
||||||
void set_location(const LocationPtr& loc) { location_ = loc; }
|
void set_location(const LocationPtr &loc) { location_ = loc; }
|
||||||
virtual LocationPtr location() { return location_; }
|
virtual LocationPtr location() { return location_; }
|
||||||
std::string name() { return name_; }
|
std::string name() { return name_; }
|
||||||
void set_name(const std::string& name) { name_ = name; }
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
virtual std::string debug_name();
|
virtual std::string debug_name();
|
||||||
|
|
||||||
virtual std::string get_python_func_belonged() { return ""; }
|
virtual std::string get_python_func_belonged() { return ""; }
|
||||||
|
@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo {
|
||||||
py_func_belonged_ = context_info->func_name();
|
py_func_belonged_ = context_info->func_name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) {
|
explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
|
||||||
if (TraceManager::CurrentContextInfo() != nullptr) {
|
if (TraceManager::CurrentContextInfo() != nullptr) {
|
||||||
auto context_info = TraceManager::CurrentContextInfo();
|
auto context_info = TraceManager::CurrentContextInfo();
|
||||||
py_func_belonged_ = context_info->func_name();
|
py_func_belonged_ = context_info->func_name();
|
||||||
|
@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo {
|
||||||
~NodeDebugInfo() override = default;
|
~NodeDebugInfo() override = default;
|
||||||
|
|
||||||
std::string debug_name() override;
|
std::string debug_name() override;
|
||||||
void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); }
|
void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
|
||||||
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
|
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
|
||||||
void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; }
|
void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
|
||||||
std::string get_python_func_belonged() override { return py_func_belonged_; }
|
std::string get_python_func_belonged() override { return py_func_belonged_; }
|
||||||
AnfNodeWeakPtr node_;
|
AnfNodeWeakPtr node_;
|
||||||
std::string py_func_belonged_;
|
std::string py_func_belonged_;
|
||||||
|
@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) {
|
explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
|
||||||
if (TraceManager::CurrentContextInfo() != nullptr) {
|
if (TraceManager::CurrentContextInfo() != nullptr) {
|
||||||
auto context_info = TraceManager::CurrentContextInfo();
|
auto context_info = TraceManager::CurrentContextInfo();
|
||||||
py_func_name_ = context_info->func_name();
|
py_func_name_ = context_info->func_name();
|
||||||
|
@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo {
|
||||||
std::string debug_name() override;
|
std::string debug_name() override;
|
||||||
LocationPtr location() override;
|
LocationPtr location() override;
|
||||||
LocationPtr deco_location() { return deco_loc_; }
|
LocationPtr deco_location() { return deco_loc_; }
|
||||||
void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
|
void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
|
||||||
FuncGraphPtr get_graph() const { return func_graph_.lock(); }
|
FuncGraphPtr get_graph() const { return func_graph_.lock(); }
|
||||||
void set_full_name(const std::string& name) { full_name_ = name; }
|
void set_full_name(const std::string &name) { full_name_ = name; }
|
||||||
std::string get_full_name() { return full_name_; }
|
std::string get_full_name() { return full_name_; }
|
||||||
void set_deco_location(const LocationPtr& deco_list_loc);
|
void set_deco_location(const LocationPtr &deco_list_loc);
|
||||||
std::string get_python_func_belonged() override { return py_func_name_; }
|
std::string get_python_func_belonged() override { return py_func_name_; }
|
||||||
FuncGraphWeakPtr func_graph_;
|
FuncGraphWeakPtr func_graph_;
|
||||||
LocationPtr deco_loc_;
|
LocationPtr deco_loc_;
|
||||||
|
|
|
@ -31,7 +31,7 @@ struct NameWithTrace {
|
||||||
std::string name;
|
std::string name;
|
||||||
std::vector<std::string> trace_labels;
|
std::vector<std::string> trace_labels;
|
||||||
};
|
};
|
||||||
static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) {
|
static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
|
||||||
switch (trace_label) {
|
switch (trace_label) {
|
||||||
case TraceLabelType::kShortSymbol:
|
case TraceLabelType::kShortSymbol:
|
||||||
return trace_info->symbol();
|
return trace_info->symbol();
|
||||||
|
@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
|
NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
|
||||||
NameWithTrace trace_name;
|
NameWithTrace trace_name;
|
||||||
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
|
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
|
||||||
auto temp_info = debug_info;
|
auto temp_info = debug_info;
|
||||||
|
@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe
|
||||||
return trace_name;
|
return trace_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) {
|
std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
|
||||||
std::string tags = "";
|
std::string tags = "";
|
||||||
for (auto& itr : trace_labels) {
|
for (auto &itr : trace_labels) {
|
||||||
std::string symbol = itr;
|
std::string symbol = itr;
|
||||||
tags = tags + symbol;
|
tags = tags + symbol;
|
||||||
}
|
}
|
||||||
|
@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the label name of the node debug info
|
// get the label name of the node debug info
|
||||||
std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
|
std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
|
||||||
NameWithTrace trace_name = RootName(debug_info, trace_label);
|
NameWithTrace trace_name = RootName(debug_info, trace_label);
|
||||||
return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
|
return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
|
std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
|
||||||
auto temp_info = debug_info;
|
auto temp_info = debug_info;
|
||||||
std::string label = "";
|
std::string label = "";
|
||||||
while (temp_info != nullptr) {
|
while (temp_info != nullptr) {
|
||||||
|
@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get trace with unique id chain
|
// get trace with unique id chain
|
||||||
std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); }
|
std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
|
||||||
|
|
||||||
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
|
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
|
||||||
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
|
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
|
||||||
return LabelStringUnique(debug_info);
|
return LabelStringUnique(debug_info);
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace label_manage {
|
||||||
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
|
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
|
||||||
TraceLabelType GetGlobalTraceLabelType();
|
TraceLabelType GetGlobalTraceLabelType();
|
||||||
void SetGlobalTraceLabelType(TraceLabelType label_type);
|
void SetGlobalTraceLabelType(TraceLabelType label_type);
|
||||||
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
|
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
|
||||||
} // namespace label_manage
|
} // namespace label_manage
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// namespace to support debug trace infomation
|
// namespace to support debug trace infomation
|
||||||
namespace trace {
|
namespace trace {
|
||||||
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) {
|
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) {
|
||||||
if (abs == nullptr) {
|
if (abs == nullptr) {
|
||||||
return "Null Abstract";
|
return "Null Abstract";
|
||||||
}
|
}
|
||||||
|
@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) {
|
||||||
return debug_with_loc_vec;
|
return debug_with_loc_vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
|
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) {
|
||||||
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
|
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
|
||||||
if (debug_with_loc_vec.size() > 0) {
|
if (debug_with_loc_vec.size() > 0) {
|
||||||
return debug_with_loc_vec[0];
|
return debug_with_loc_vec[0];
|
||||||
|
@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
|
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
|
||||||
if (info == nullptr) {
|
if (info == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
|
||||||
|
|
||||||
// a trace info identifies a node transform, so we can trace the node transform through
|
// a trace info identifies a node transform, so we can trace the node transform through
|
||||||
// a link of trace info and debug info
|
// a link of trace info and debug info
|
||||||
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) {
|
std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) {
|
||||||
if (info_vec.size() < 1) {
|
if (info_vec.size() < 1) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL
|
||||||
return traced_info;
|
return traced_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
|
std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
|
||||||
if (info == nullptr) {
|
if (info == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) {
|
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
if (info == nullptr) {
|
if (info == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
|
@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) {
|
std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "graph:" << graph->ToString() << " with args[";
|
oss << "graph:" << graph->ToString() << " with args[";
|
||||||
auto params = graph->parameters();
|
auto params = graph->parameters();
|
||||||
|
@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DumpInferStack(std::ostringstream& oss) {
|
void DumpInferStack(std::ostringstream &oss) {
|
||||||
auto& infer_stack = GetCurrenGraphInferStack();
|
auto &infer_stack = GetCurrenGraphInferStack();
|
||||||
if (infer_stack.empty()) {
|
if (infer_stack.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) {
|
||||||
}
|
}
|
||||||
std::reverse(infer_vec.begin(), infer_vec.end());
|
std::reverse(infer_vec.begin(), infer_vec.end());
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (auto& item : infer_vec) {
|
for (auto &item : infer_vec) {
|
||||||
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
|
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
|
||||||
if (graph_infer == nullptr) {
|
if (graph_infer == nullptr) {
|
||||||
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
|
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
|
||||||
|
@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceGraphInfer() {
|
void TraceGraphInfer() {
|
||||||
auto& infer_stack = GetCurrenGraphInferStack();
|
auto &infer_stack = GetCurrenGraphInferStack();
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
if (infer_stack.empty()) {
|
if (infer_stack.empty()) {
|
||||||
return;
|
return;
|
||||||
|
@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
|
||||||
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
|
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
|
||||||
~AnalyzedFuncGraphExporter() override = default;
|
~AnalyzedFuncGraphExporter() override = default;
|
||||||
|
|
||||||
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
|
void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string GetNodeType(const AnfNodePtr& nd) override;
|
std::string GetNodeType(const AnfNodePtr &nd) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
|
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
|
||||||
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
|
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
|
||||||
auto& list = GetCNodeDebugStack();
|
auto &list = GetCNodeDebugStack();
|
||||||
for (size_t i = 0; i < list.size(); ++i) {
|
for (size_t i = 0; i < list.size(); ++i) {
|
||||||
auto node_cfg = list[i];
|
auto node_cfg = list[i];
|
||||||
auto fg = node_cfg->context()->func_graph();
|
auto fg = node_cfg->context()->func_graph();
|
||||||
|
@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() {
|
||||||
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
|
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
|
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
|
||||||
if (node_cfg_ == nullptr) {
|
if (node_cfg_ == nullptr) {
|
||||||
return AnfExporter::GetNodeType(node);
|
return AnfExporter::GetNodeType(node);
|
||||||
}
|
}
|
||||||
|
@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
|
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
|
||||||
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
|
const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) {
|
||||||
if (node_cfgs.empty()) {
|
if (node_cfgs.empty()) {
|
||||||
MS_LOG(DEBUG) << "Node configs is empty";
|
MS_LOG(DEBUG) << "Node configs is empty";
|
||||||
return;
|
return;
|
||||||
|
@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
|
||||||
auto tagged_func_graphs = CalcTaggedFuncGraphs();
|
auto tagged_func_graphs = CalcTaggedFuncGraphs();
|
||||||
|
|
||||||
// first output graph on the analysis stack
|
// first output graph on the analysis stack
|
||||||
for (const auto& node_cfg : node_cfgs) {
|
for (const auto &node_cfg : node_cfgs) {
|
||||||
auto fg = node_cfg->context()->func_graph();
|
auto fg = node_cfg->context()->func_graph();
|
||||||
// the graph is already output, skip it
|
// the graph is already output, skip it
|
||||||
if (exported.find(fg) != exported.end()) {
|
if (exported.find(fg) != exported.end()) {
|
||||||
|
@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
|
||||||
ofs.close();
|
ofs.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetInferStackInfo(std::ostringstream& oss) {
|
void GetInferStackInfo(std::ostringstream &oss) {
|
||||||
MS_LOG(INFO) << "Get graph analysis information begin";
|
MS_LOG(INFO) << "Get graph analysis information begin";
|
||||||
auto stack = GetCNodeDebugStack();
|
auto stack = GetCNodeDebugStack();
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
|
@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
|
||||||
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
|
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
|
||||||
// trace the cnode infer debug info
|
// trace the cnode infer debug info
|
||||||
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
|
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
|
||||||
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) {
|
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
|
||||||
if (eval == nullptr) {
|
if (eval == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
|
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
|
||||||
}
|
}
|
||||||
|
@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
|
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
|
||||||
if (eval == nullptr) {
|
if (eval == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
|
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
|
||||||
}
|
}
|
||||||
|
@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); }
|
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
|
||||||
|
|
||||||
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }
|
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }
|
||||||
|
|
||||||
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; }
|
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
|
||||||
|
|
||||||
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() {
|
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() {
|
||||||
return graph_infer_stack;
|
return graph_infer_stack;
|
||||||
}
|
}
|
||||||
void ClearTraceStack() {
|
void ClearTraceStack() {
|
||||||
|
|
|
@ -31,19 +31,19 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace trace {
|
namespace trace {
|
||||||
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine);
|
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
|
||||||
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix,
|
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
|
||||||
SourceLineTip tip = kSourceLineTipNextLine);
|
SourceLineTip tip = kSourceLineTipNextLine);
|
||||||
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info);
|
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
|
||||||
void TraceGraphInfer();
|
void TraceGraphInfer();
|
||||||
void GetInferStackInfo(std::ostringstream& oss);
|
void GetInferStackInfo(std::ostringstream &oss);
|
||||||
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node);
|
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
|
||||||
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval);
|
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval);
|
||||||
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg);
|
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
|
||||||
void TraceInferCNodeLeave();
|
void TraceInferCNodeLeave();
|
||||||
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack();
|
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
|
||||||
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack();
|
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
|
||||||
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs);
|
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
|
||||||
void ClearTraceStack();
|
void ClearTraceStack();
|
||||||
} // namespace trace
|
} // namespace trace
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include "pipeline/parse/python_adapter.h"
|
#include "pipeline/parse/python_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) {
|
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
|
||||||
if (info == nullptr) {
|
if (info == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>;
|
||||||
// namespace to support intermediate representation definition
|
// namespace to support intermediate representation definition
|
||||||
class TraceInfo : public Base {
|
class TraceInfo : public Base {
|
||||||
public:
|
public:
|
||||||
TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) {
|
TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) {
|
||||||
symbol_ = symbol;
|
symbol_ = symbol;
|
||||||
full_name_ = full_name;
|
full_name_ = full_name;
|
||||||
name_ = full_name_;
|
name_ = full_name_;
|
||||||
debug_info_ = info;
|
debug_info_ = info;
|
||||||
}
|
}
|
||||||
TraceInfo(const TraceInfo& info)
|
TraceInfo(const TraceInfo &info)
|
||||||
: Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {}
|
: Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {}
|
||||||
virtual ~TraceInfo() = default;
|
virtual ~TraceInfo() = default;
|
||||||
MS_DECLARE_PARENT(TraceInfo, Base);
|
MS_DECLARE_PARENT(TraceInfo, Base);
|
||||||
|
@ -55,8 +55,8 @@ class TraceInfo : public Base {
|
||||||
virtual std::string full_name() { return full_name_; }
|
virtual std::string full_name() { return full_name_; }
|
||||||
virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); }
|
virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); }
|
||||||
virtual std::string action_name() { return ""; }
|
virtual std::string action_name() { return ""; }
|
||||||
virtual std::string GetActionBetweenNode(const DebugInfoPtr& info);
|
virtual std::string GetActionBetweenNode(const DebugInfoPtr &info);
|
||||||
void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; }
|
void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; }
|
||||||
DebugInfoPtr debug_info() { return debug_info_; }
|
DebugInfoPtr debug_info() { return debug_info_; }
|
||||||
DebugInfoPtr DebugInfoHasLoc();
|
DebugInfoPtr DebugInfoHasLoc();
|
||||||
std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo();
|
std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo();
|
||||||
|
@ -70,7 +70,7 @@ class TraceInfo : public Base {
|
||||||
|
|
||||||
class TracePhi : public TraceInfo {
|
class TracePhi : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {}
|
explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {}
|
||||||
MS_DECLARE_PARENT(TracePhi, TraceInfo);
|
MS_DECLARE_PARENT(TracePhi, TraceInfo);
|
||||||
~TracePhi() override = default;
|
~TracePhi() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); }
|
||||||
|
@ -78,8 +78,8 @@ class TracePhi : public TraceInfo {
|
||||||
|
|
||||||
class TraceIfStmtTrueBranch : public TraceInfo {
|
class TraceIfStmtTrueBranch : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default;
|
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default;
|
||||||
explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {}
|
explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {}
|
||||||
MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo);
|
MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo);
|
||||||
~TraceIfStmtTrueBranch() override = default;
|
~TraceIfStmtTrueBranch() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo {
|
||||||
|
|
||||||
class TraceIfStmtFalseBranch : public TraceInfo {
|
class TraceIfStmtFalseBranch : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default;
|
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default;
|
||||||
explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {}
|
explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {}
|
||||||
MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo);
|
MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo);
|
||||||
~TraceIfStmtFalseBranch() override = default;
|
~TraceIfStmtFalseBranch() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo {
|
||||||
|
|
||||||
class TraceIfStmtAfterBranch : public TraceInfo {
|
class TraceIfStmtAfterBranch : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {}
|
explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {}
|
||||||
MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo);
|
MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo);
|
||||||
~TraceIfStmtAfterBranch() override = default;
|
~TraceIfStmtAfterBranch() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo {
|
||||||
|
|
||||||
class TraceIfExpTrueBranch : public TraceInfo {
|
class TraceIfExpTrueBranch : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {}
|
explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {}
|
||||||
MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo);
|
MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo);
|
||||||
~TraceIfExpTrueBranch() override = default;
|
~TraceIfExpTrueBranch() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo {
|
||||||
|
|
||||||
class TraceIfExpFalseBranch : public TraceInfo {
|
class TraceIfExpFalseBranch : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {}
|
explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {}
|
||||||
MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo);
|
MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo);
|
||||||
~TraceIfExpFalseBranch() override = default;
|
~TraceIfExpFalseBranch() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo {
|
||||||
class TraceCopy : public TraceInfo {
|
class TraceCopy : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceCopy() : TraceInfo(nullptr, "copy", "") {}
|
TraceCopy() : TraceInfo(nullptr, "copy", "") {}
|
||||||
explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {}
|
explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {}
|
||||||
MS_DECLARE_PARENT(TraceCopy, TraceInfo);
|
MS_DECLARE_PARENT(TraceCopy, TraceInfo);
|
||||||
~TraceCopy() override = default;
|
~TraceCopy() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); }
|
||||||
|
@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo {
|
||||||
|
|
||||||
class TraceIterator : public TraceInfo {
|
class TraceIterator : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {}
|
explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {}
|
||||||
MS_DECLARE_PARENT(TraceIterator, TraceInfo);
|
MS_DECLARE_PARENT(TraceIterator, TraceInfo);
|
||||||
~TraceIterator() override = default;
|
~TraceIterator() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); }
|
||||||
|
@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo {
|
||||||
|
|
||||||
class TraceWhileHeader : public TraceInfo {
|
class TraceWhileHeader : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {}
|
explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {}
|
||||||
MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo);
|
MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo);
|
||||||
~TraceWhileHeader() override = default;
|
~TraceWhileHeader() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); }
|
||||||
|
@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo {
|
||||||
|
|
||||||
class TraceWhileBody : public TraceInfo {
|
class TraceWhileBody : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {}
|
explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {}
|
||||||
MS_DECLARE_PARENT(TraceWhileBody, TraceInfo);
|
MS_DECLARE_PARENT(TraceWhileBody, TraceInfo);
|
||||||
~TraceWhileBody() override = default;
|
~TraceWhileBody() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); }
|
||||||
|
@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo {
|
||||||
|
|
||||||
class TraceWhileAfter : public TraceInfo {
|
class TraceWhileAfter : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {}
|
explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {}
|
||||||
MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo);
|
MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo);
|
||||||
~TraceWhileAfter() override = default;
|
~TraceWhileAfter() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); }
|
||||||
|
@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo {
|
||||||
|
|
||||||
class TraceForHeader : public TraceInfo {
|
class TraceForHeader : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {}
|
explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {}
|
||||||
MS_DECLARE_PARENT(TraceForHeader, TraceInfo);
|
MS_DECLARE_PARENT(TraceForHeader, TraceInfo);
|
||||||
~TraceForHeader() override = default;
|
~TraceForHeader() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); }
|
||||||
|
@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo {
|
||||||
|
|
||||||
class TraceForBody : public TraceInfo {
|
class TraceForBody : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {}
|
explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {}
|
||||||
MS_DECLARE_PARENT(TraceForBody, TraceInfo);
|
MS_DECLARE_PARENT(TraceForBody, TraceInfo);
|
||||||
~TraceForBody() override = default;
|
~TraceForBody() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); }
|
||||||
|
@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo {
|
||||||
|
|
||||||
class TraceForAfter : public TraceInfo {
|
class TraceForAfter : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {}
|
explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {}
|
||||||
MS_DECLARE_PARENT(TraceForAfter, TraceInfo);
|
MS_DECLARE_PARENT(TraceForAfter, TraceInfo);
|
||||||
~TraceForAfter() override = default;
|
~TraceForAfter() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
|
||||||
|
@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo {
|
||||||
|
|
||||||
class TraceEquiv : public TraceInfo {
|
class TraceEquiv : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {}
|
explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {}
|
||||||
MS_DECLARE_PARENT(TraceEquiv, TraceInfo);
|
MS_DECLARE_PARENT(TraceEquiv, TraceInfo);
|
||||||
~TraceEquiv() override = default;
|
~TraceEquiv() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); }
|
||||||
|
@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo {
|
||||||
class TraceGradFpropApp : public TraceInfo {
|
class TraceGradFpropApp : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {}
|
TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {}
|
||||||
explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {}
|
explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {}
|
||||||
MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo);
|
MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo);
|
||||||
~TraceGradFpropApp() override = default;
|
~TraceGradFpropApp() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); }
|
||||||
|
@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo {
|
||||||
class TraceGradBpropApp : public TraceInfo {
|
class TraceGradBpropApp : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {}
|
TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {}
|
||||||
explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {}
|
explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {}
|
||||||
MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo);
|
MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo);
|
||||||
~TraceGradBpropApp() override = default;
|
~TraceGradBpropApp() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); }
|
||||||
|
@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo {
|
||||||
class TraceGradFprop : public TraceInfo {
|
class TraceGradFprop : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {}
|
TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {}
|
||||||
explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {}
|
explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {}
|
||||||
MS_DECLARE_PARENT(TraceGradFprop, TraceInfo);
|
MS_DECLARE_PARENT(TraceGradFprop, TraceInfo);
|
||||||
~TraceGradFprop() override = default;
|
~TraceGradFprop() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); }
|
||||||
|
@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo {
|
||||||
class TraceGradBprop : public TraceInfo {
|
class TraceGradBprop : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {}
|
TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {}
|
||||||
explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {}
|
explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {}
|
||||||
MS_DECLARE_PARENT(TraceGradBprop, TraceInfo);
|
MS_DECLARE_PARENT(TraceGradBprop, TraceInfo);
|
||||||
~TraceGradBprop() override = default;
|
~TraceGradBprop() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); }
|
||||||
|
@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo {
|
||||||
class TraceGradSens : public TraceInfo {
|
class TraceGradSens : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {}
|
TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {}
|
||||||
explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {}
|
explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {}
|
||||||
MS_DECLARE_PARENT(TraceGradSens, TraceInfo);
|
MS_DECLARE_PARENT(TraceGradSens, TraceInfo);
|
||||||
~TraceGradSens() override = default;
|
~TraceGradSens() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); }
|
||||||
|
@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo {
|
||||||
|
|
||||||
class TraceSpecialize : public TraceInfo {
|
class TraceSpecialize : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
|
explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
|
||||||
MS_DECLARE_PARENT(TraceSpecialize, TraceInfo);
|
MS_DECLARE_PARENT(TraceSpecialize, TraceInfo);
|
||||||
std::string name() override { return full_name_ + counter_; }
|
std::string name() override { return full_name_ + counter_; }
|
||||||
std::string symbol() override { return counter_ + "_"; }
|
std::string symbol() override { return counter_ + "_"; }
|
||||||
|
@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo {
|
||||||
|
|
||||||
class TraceGradOperation : public TraceInfo {
|
class TraceGradOperation : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {}
|
explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {}
|
||||||
MS_DECLARE_PARENT(TraceGradOperation, TraceInfo);
|
MS_DECLARE_PARENT(TraceGradOperation, TraceInfo);
|
||||||
~TraceGradOperation() override = default;
|
~TraceGradOperation() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo {
|
||||||
|
|
||||||
class TraceForceBool : public TraceInfo {
|
class TraceForceBool : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {}
|
explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {}
|
||||||
MS_DECLARE_PARENT(TraceForceBool, TraceInfo);
|
MS_DECLARE_PARENT(TraceForceBool, TraceInfo);
|
||||||
~TraceForceBool() override = default;
|
~TraceForceBool() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
|
||||||
|
@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo {
|
||||||
|
|
||||||
class TraceExpandJ : public TraceInfo {
|
class TraceExpandJ : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {}
|
explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}
|
||||||
MS_DECLARE_PARENT(TraceExpandJ, TraceInfo);
|
MS_DECLARE_PARENT(TraceExpandJ, TraceInfo);
|
||||||
~TraceExpandJ() override = default;
|
~TraceExpandJ() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); }
|
||||||
|
@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo {
|
||||||
|
|
||||||
class TraceGenMetaFuncGraph : public TraceInfo {
|
class TraceGenMetaFuncGraph : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
|
explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
|
||||||
MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo);
|
MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo);
|
||||||
~TraceGenMetaFuncGraph() override = default;
|
~TraceGenMetaFuncGraph() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo {
|
||||||
|
|
||||||
class TraceEvaluatorGenGraph : public TraceInfo {
|
class TraceEvaluatorGenGraph : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
|
explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
|
||||||
MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo);
|
MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo);
|
||||||
~TraceEvaluatorGenGraph() override = default;
|
~TraceEvaluatorGenGraph() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo {
|
||||||
|
|
||||||
class TraceResolve : public TraceInfo {
|
class TraceResolve : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {}
|
explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {}
|
||||||
MS_DECLARE_PARENT(TraceResolve, TraceInfo);
|
MS_DECLARE_PARENT(TraceResolve, TraceInfo);
|
||||||
~TraceResolve() override = default;
|
~TraceResolve() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); }
|
||||||
|
@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo {
|
||||||
class TraceTransform : public TraceInfo {
|
class TraceTransform : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; }
|
TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; }
|
||||||
explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") {
|
explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") {
|
||||||
transform_name_ = transform_name;
|
transform_name_ = transform_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo {
|
||||||
|
|
||||||
class TraceGenerateVarArg : public TraceInfo {
|
class TraceGenerateVarArg : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {}
|
explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {}
|
||||||
MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo);
|
MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo);
|
||||||
~TraceGenerateVarArg() override = default;
|
~TraceGenerateVarArg() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo {
|
||||||
|
|
||||||
class TraceGenerateKwArg : public TraceInfo {
|
class TraceGenerateKwArg : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {}
|
explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {}
|
||||||
MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo);
|
MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo);
|
||||||
~TraceGenerateKwArg() override = default;
|
~TraceGenerateKwArg() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo {
|
||||||
|
|
||||||
class TraceTrasformK : public TraceInfo {
|
class TraceTrasformK : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {}
|
explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {}
|
||||||
MS_DECLARE_PARENT(TraceTrasformK, TraceInfo);
|
MS_DECLARE_PARENT(TraceTrasformK, TraceInfo);
|
||||||
~TraceTrasformK() override = default;
|
~TraceTrasformK() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); }
|
||||||
|
@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo {
|
||||||
|
|
||||||
class TracePartialTransform : public TraceInfo {
|
class TracePartialTransform : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {}
|
explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {}
|
||||||
MS_DECLARE_PARENT(TracePartialTransform, TraceInfo);
|
MS_DECLARE_PARENT(TracePartialTransform, TraceInfo);
|
||||||
~TracePartialTransform() override = default;
|
~TracePartialTransform() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo {
|
||||||
|
|
||||||
class TraceGetEnv : public TraceInfo {
|
class TraceGetEnv : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {}
|
explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {}
|
||||||
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
|
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
|
||||||
~TraceGetEnv() override = default;
|
~TraceGetEnv() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); }
|
||||||
|
@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo {
|
||||||
|
|
||||||
class TraceDoSignature : public TraceInfo {
|
class TraceDoSignature : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {}
|
explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {}
|
||||||
MS_DECLARE_PARENT(TraceDoSignature, TraceInfo);
|
MS_DECLARE_PARENT(TraceDoSignature, TraceInfo);
|
||||||
~TraceDoSignature() override = default;
|
~TraceDoSignature() override = default;
|
||||||
TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); }
|
TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); }
|
||||||
|
@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo {
|
||||||
class TraceCombileLikeGraphs : public TraceInfo {
|
class TraceCombileLikeGraphs : public TraceInfo {
|
||||||
public:
|
public:
|
||||||
TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {}
|
TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {}
|
||||||
explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {}
|
explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {}
|
||||||
MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo);
|
MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo);
|
||||||
~TraceCombileLikeGraphs() override = default;
|
~TraceCombileLikeGraphs() override = default;
|
||||||
TraceInfoPtr clone() override {
|
TraceInfoPtr clone() override {
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace ascend {
|
namespace ascend {
|
||||||
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
|
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
|
||||||
if (has_malloc_) {
|
if (has_malloc_) {
|
||||||
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
|
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) {
|
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
|
||||||
MS_EXCEPTION_IF_NULL(addr);
|
MS_EXCEPTION_IF_NULL(addr);
|
||||||
has_malloc_ = false;
|
has_malloc_ = false;
|
||||||
free_mem_size_ = total_mem_size_;
|
free_mem_size_ = total_mem_size_;
|
||||||
|
@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
|
||||||
|
|
||||||
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
|
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
|
||||||
|
|
||||||
void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) {
|
void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
|
||||||
MS_EXCEPTION_IF_NULL(device_mem_pool_base);
|
MS_EXCEPTION_IF_NULL(device_mem_pool_base);
|
||||||
device_mem_pool_base_ = device_mem_pool_base;
|
device_mem_pool_base_ = device_mem_pool_base;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,12 +26,12 @@ namespace ascend {
|
||||||
class AscendMemoryPool : public DynamicMemPoolBestFit {
|
class AscendMemoryPool : public DynamicMemPoolBestFit {
|
||||||
public:
|
public:
|
||||||
~AscendMemoryPool() override = default;
|
~AscendMemoryPool() override = default;
|
||||||
AscendMemoryPool(const AscendMemoryPool&) = delete;
|
AscendMemoryPool(const AscendMemoryPool &) = delete;
|
||||||
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
|
AscendMemoryPool &operator=(const AscendMemoryPool &) = delete;
|
||||||
|
|
||||||
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
|
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
|
||||||
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
|
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
|
||||||
void set_device_mem_pool_base(uint8_t* device_mem_pool_base);
|
void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
|
||||||
void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
|
void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
|
||||||
device_mem_pool_size_ = device_mem_pool_size;
|
device_mem_pool_size_ = device_mem_pool_size;
|
||||||
free_mem_size_ = device_mem_pool_size_;
|
free_mem_size_ = device_mem_pool_size_;
|
||||||
|
@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
|
||||||
size_t free_mem_size() override;
|
size_t free_mem_size() override;
|
||||||
size_t total_mem_size() override;
|
size_t total_mem_size() override;
|
||||||
|
|
||||||
static AscendMemoryPool& GetInstance() {
|
static AscendMemoryPool &GetInstance() {
|
||||||
static AscendMemoryPool instance;
|
static AscendMemoryPool instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
|
||||||
private:
|
private:
|
||||||
AscendMemoryPool() = default;
|
AscendMemoryPool() = default;
|
||||||
bool has_malloc_{false};
|
bool has_malloc_{false};
|
||||||
uint8_t* device_mem_pool_base_{nullptr};
|
uint8_t *device_mem_pool_base_{nullptr};
|
||||||
uint64_t device_mem_pool_size_{0};
|
uint64_t device_mem_pool_size_{0};
|
||||||
size_t free_mem_size_{0};
|
size_t free_mem_size_{0};
|
||||||
size_t total_mem_size_{0};
|
size_t total_mem_size_{0};
|
||||||
|
|
|
@ -39,13 +39,13 @@ using std::vector;
|
||||||
|
|
||||||
class AscendStreamAssign {
|
class AscendStreamAssign {
|
||||||
public:
|
public:
|
||||||
static AscendStreamAssign& GetInstance() {
|
static AscendStreamAssign &GetInstance() {
|
||||||
static AscendStreamAssign instance; // Guaranteed to be destroyed.
|
static AscendStreamAssign instance; // Guaranteed to be destroyed.
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
AscendStreamAssign(const AscendStreamAssign&) = delete;
|
AscendStreamAssign(const AscendStreamAssign &) = delete;
|
||||||
AscendStreamAssign& operator=(const AscendStreamAssign&) = delete;
|
AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
|
||||||
|
|
||||||
uint32_t GetTotalStreamNum() const;
|
uint32_t GetTotalStreamNum() const;
|
||||||
// new stream policy
|
// new stream policy
|
||||||
|
@ -53,19 +53,19 @@ class AscendStreamAssign {
|
||||||
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
|
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
|
||||||
uint32_t total_event_num() const { return total_event_num_; }
|
uint32_t total_event_num() const { return total_event_num_; }
|
||||||
|
|
||||||
void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void ResetNew();
|
void ResetNew();
|
||||||
void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
bool IsIndependentNode(const CNodePtr& node_ptr);
|
bool IsIndependentNode(const CNodePtr &node_ptr);
|
||||||
const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; }
|
const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
|
||||||
const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; }
|
const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
|
||||||
const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; }
|
const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
|
||||||
void GetWaitStreams(vector<uint32_t>* wait_active_stream_list);
|
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
|
||||||
const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; }
|
const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
|
||||||
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
|
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
|
||||||
uint32_t stream_id);
|
uint32_t stream_id);
|
||||||
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
|
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
|
||||||
uint32_t stream_id);
|
uint32_t stream_id);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -73,30 +73,30 @@ class AscendStreamAssign {
|
||||||
~AscendStreamAssign() = default;
|
~AscendStreamAssign() = default;
|
||||||
|
|
||||||
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
||||||
const CNodePtr& node);
|
const CNodePtr &node);
|
||||||
|
|
||||||
bool IsHcom(const CNodePtr& apply_kernel);
|
bool IsHcom(const CNodePtr &apply_kernel);
|
||||||
bool IsProcessed(uint32_t logic_id);
|
bool IsProcessed(uint32_t logic_id);
|
||||||
void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids);
|
void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
|
||||||
void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index,
|
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
|
||||||
uint32_t* cur_stream_id);
|
uint32_t *cur_stream_id);
|
||||||
void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
|
void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
|
||||||
void UpdateStreamActive(const CNodePtr& active_ptr);
|
void UpdateStreamActive(const CNodePtr &active_ptr);
|
||||||
void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr);
|
void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
|
||||||
bool IsTaskSink();
|
bool IsTaskSink();
|
||||||
void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id);
|
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
|
||||||
void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
|
void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
|
||||||
uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr);
|
uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
|
||||||
void SetCommonStreamNum(uint32_t cur_stream_id);
|
void SetCommonStreamNum(uint32_t cur_stream_id);
|
||||||
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
bool IsProcessedParallelStream(uint32_t stream_id);
|
bool IsProcessedParallelStream(uint32_t stream_id);
|
||||||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams);
|
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||||
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr);
|
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
|
|
||||||
uint32_t total_common_stream_num_{0};
|
uint32_t total_common_stream_num_{0};
|
||||||
uint32_t total_independ_stream_num_{0};
|
uint32_t total_independ_stream_num_{0};
|
||||||
|
|
|
@ -28,14 +28,14 @@ namespace device {
|
||||||
namespace ascend {
|
namespace ascend {
|
||||||
class PluginImpl : public PluginIntf {
|
class PluginImpl : public PluginIntf {
|
||||||
public:
|
public:
|
||||||
explicit PluginImpl(const std::string& module);
|
explicit PluginImpl(const std::string &module);
|
||||||
~PluginImpl() override = default;
|
~PluginImpl() override = default;
|
||||||
int Init(const Reporter* reporter) override;
|
int Init(const Reporter *reporter) override;
|
||||||
int UnInit() override;
|
int UnInit() override;
|
||||||
static Reporter* GetPluginReporter() { return reporter_; }
|
static Reporter *GetPluginReporter() { return reporter_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static Reporter* reporter_;
|
static Reporter *reporter_;
|
||||||
std::string module_;
|
std::string module_;
|
||||||
};
|
};
|
||||||
} // namespace ascend
|
} // namespace ascend
|
||||||
|
|
|
@ -20,12 +20,12 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace ascend {
|
namespace ascend {
|
||||||
PluginIntf* ProfilingEngineImpl::CreatePlugin() {
|
PluginIntf *ProfilingEngineImpl::CreatePlugin() {
|
||||||
MS_LOG(INFO) << "Create Plugin.";
|
MS_LOG(INFO) << "Create Plugin.";
|
||||||
return new (std::nothrow) PluginImpl("Framework");
|
return new (std::nothrow) PluginImpl("Framework");
|
||||||
}
|
}
|
||||||
|
|
||||||
int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) {
|
int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) {
|
||||||
if (plugin != nullptr) {
|
if (plugin != nullptr) {
|
||||||
delete plugin;
|
delete plugin;
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf {
|
||||||
ProfilingEngineImpl() = default;
|
ProfilingEngineImpl() = default;
|
||||||
~ProfilingEngineImpl() override = default;
|
~ProfilingEngineImpl() override = default;
|
||||||
|
|
||||||
PluginIntf* CreatePlugin() override;
|
PluginIntf *CreatePlugin() override;
|
||||||
int ReleasePlugin(PluginIntf* plugin) override;
|
int ReleasePlugin(PluginIntf *plugin) override;
|
||||||
};
|
};
|
||||||
} // namespace ascend
|
} // namespace ascend
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|
|
@ -35,7 +35,7 @@ using Json = nlohmann::json;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace ascend {
|
namespace ascend {
|
||||||
ProfilingManager& ProfilingManager::GetInstance() {
|
ProfilingManager &ProfilingManager::GetInstance() {
|
||||||
static ProfilingManager inst;
|
static ProfilingManager inst;
|
||||||
return inst;
|
return inst;
|
||||||
}
|
}
|
||||||
|
@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ProfilingManager::GetJobId() const {
|
uint64_t ProfilingManager::GetJobId() const {
|
||||||
const char* job_id = std::getenv("JOB_ID");
|
const char *job_id = std::getenv("JOB_ID");
|
||||||
return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0);
|
return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const {
|
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const {
|
||||||
if (!IsProfiling()) {
|
if (!IsProfiling()) {
|
||||||
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
|
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
|
||||||
MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size();
|
MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size();
|
||||||
|
|
||||||
Msprof::Engine::ReporterData reporter_data = {};
|
Msprof::Engine::ReporterData reporter_data = {};
|
||||||
for (const auto& iter : op_taskId_map) {
|
for (const auto &iter : op_taskId_map) {
|
||||||
auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
|
auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
|
||||||
reporter_data.deviceId = UintToInt(device_id_);
|
reporter_data.deviceId = UintToInt(device_id_);
|
||||||
reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str()));
|
reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
|
||||||
reporter_data.dataLen = data.size();
|
reporter_data.dataLen = data.size();
|
||||||
auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
|
auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
|
@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<std::string> Split(const std::string& str, const char delim) {
|
static std::vector<std::string> Split(const std::string &str, const char delim) {
|
||||||
std::vector<std::string> elems;
|
std::vector<std::string> elems;
|
||||||
|
|
||||||
if (str.empty()) {
|
if (str.empty()) {
|
||||||
|
@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
|
||||||
device_id_ = device_id;
|
device_id_ = device_id;
|
||||||
// exp: export PROFILING_MODE=true
|
// exp: export PROFILING_MODE=true
|
||||||
// export PROFILING_OPTIONS=training_trace
|
// export PROFILING_OPTIONS=training_trace
|
||||||
const char* prof_options_str = std::getenv("PROFILING_OPTIONS");
|
const char *prof_options_str = std::getenv("PROFILING_OPTIONS");
|
||||||
// register Framework to profiling
|
// register Framework to profiling
|
||||||
int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
|
int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
|
||||||
if (result != 0) {
|
if (result != 0) {
|
||||||
|
@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const {
|
||||||
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
|
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter();
|
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
|
||||||
if (reporter != nullptr) {
|
if (reporter != nullptr) {
|
||||||
MS_LOG(INFO) << "report data end, ret = " << reporter->Flush();
|
MS_LOG(INFO) << "report data end, ret = " << reporter->Flush();
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
|
||||||
|
|
||||||
class GpuQueue {
|
class GpuQueue {
|
||||||
public:
|
public:
|
||||||
GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity);
|
GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity);
|
||||||
virtual ~GpuQueue();
|
virtual ~GpuQueue();
|
||||||
|
|
||||||
void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; }
|
void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; }
|
||||||
|
|
||||||
inline bool IsEmpty() const { return head_ == tail_; }
|
inline bool IsEmpty() const { return head_ == tail_; }
|
||||||
inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); }
|
inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); }
|
||||||
|
|
||||||
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size);
|
BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size);
|
||||||
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const;
|
BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const;
|
||||||
BlockQueueStatus_T Pop();
|
BlockQueueStatus_T Pop();
|
||||||
bool Destroy();
|
bool Destroy();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct NodeInfo {
|
struct NodeInfo {
|
||||||
std::unique_ptr<cudaEvent_t> event_;
|
std::unique_ptr<cudaEvent_t> event_;
|
||||||
void* host_feature_addr_;
|
void *host_feature_addr_;
|
||||||
void* host_label_addr_;
|
void *host_label_addr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* buffer_;
|
void *buffer_;
|
||||||
size_t head_;
|
size_t head_;
|
||||||
size_t tail_;
|
size_t tail_;
|
||||||
size_t feature_size_;
|
size_t feature_size_;
|
||||||
|
@ -61,10 +61,10 @@ class GpuQueue {
|
||||||
size_t capacity_;
|
size_t capacity_;
|
||||||
cudaStream_t stream_;
|
cudaStream_t stream_;
|
||||||
std::unique_ptr<NodeInfo[]> node_info_;
|
std::unique_ptr<NodeInfo[]> node_info_;
|
||||||
std::function<void(void*)> host_release_;
|
std::function<void(void *)> host_release_;
|
||||||
|
|
||||||
GpuQueue(const GpuQueue&) = delete;
|
GpuQueue(const GpuQueue &) = delete;
|
||||||
GpuQueue& operator=(const GpuQueue&) = delete;
|
GpuQueue &operator=(const GpuQueue &) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BlockingQueue {
|
class BlockingQueue {
|
||||||
|
@ -72,11 +72,11 @@ class BlockingQueue {
|
||||||
BlockingQueue() : queue_(nullptr) {}
|
BlockingQueue() : queue_(nullptr) {}
|
||||||
~BlockingQueue() = default;
|
~BlockingQueue() = default;
|
||||||
|
|
||||||
BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity);
|
BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity);
|
||||||
void RegisterRelease(const std::function<void(void*)>& func);
|
void RegisterRelease(const std::function<void(void *)> &func);
|
||||||
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size,
|
BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size,
|
||||||
unsigned int timeout_in_sec);
|
unsigned int timeout_in_sec);
|
||||||
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size);
|
BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size);
|
||||||
BlockQueueStatus_T Pop();
|
BlockQueueStatus_T Pop();
|
||||||
bool Destroy();
|
bool Destroy();
|
||||||
|
|
||||||
|
|
|
@ -20,17 +20,17 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
CollectiveInitializer& CollectiveInitializer::instance() {
|
CollectiveInitializer &CollectiveInitializer::instance() {
|
||||||
static CollectiveInitializer instance = {};
|
static CollectiveInitializer instance = {};
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CollectiveInitializer::collective_inited() const { return collective_inited_; }
|
bool CollectiveInitializer::collective_inited() const { return collective_inited_; }
|
||||||
|
|
||||||
const void* CollectiveInitializer::collective_handle() const { return collective_handle_; }
|
const void *CollectiveInitializer::collective_handle() const { return collective_handle_; }
|
||||||
|
|
||||||
void CollectiveInitializer::InitCollective() {
|
void CollectiveInitializer::InitCollective() {
|
||||||
void* handle = dlopen("libgpu_collective.so", RTLD_LAZY);
|
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
|
||||||
if (handle == nullptr) {
|
if (handle == nullptr) {
|
||||||
MS_LOG(EXCEPTION)
|
MS_LOG(EXCEPTION)
|
||||||
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not "
|
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not "
|
||||||
|
|
|
@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() {
|
||||||
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
|
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPUDeviceManager::CreateStream(DeviceStream* stream) {
|
bool GPUDeviceManager::CreateStream(DeviceStream *stream) {
|
||||||
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
|
||||||
gpu_streams_.emplace_back(*stream);
|
gpu_streams_.emplace_back(*stream);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; }
|
const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; }
|
||||||
|
|
||||||
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
|
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
|
||||||
|
|
||||||
|
@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; }
|
||||||
|
|
||||||
bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
|
bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
|
||||||
|
|
||||||
const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
|
const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
|
||||||
|
|
||||||
const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
|
const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
|
||||||
|
|
||||||
bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); }
|
bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
|
||||||
|
|
||||||
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const {
|
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {
|
||||||
return CudaDriver::CopyDeviceMemToHost(dst, src, size);
|
return CudaDriver::CopyDeviceMemToHost(dst, src, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const {
|
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const {
|
||||||
return CudaDriver::CopyHostMemToDevice(dst, src, size);
|
return CudaDriver::CopyHostMemToDevice(dst, src, size);
|
||||||
}
|
}
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
|
|
@ -37,17 +37,17 @@ class GPUDeviceManager {
|
||||||
uint32_t cur_device_id() const;
|
uint32_t cur_device_id() const;
|
||||||
bool is_device_id_init() const;
|
bool is_device_id_init() const;
|
||||||
|
|
||||||
bool CreateStream(DeviceStream* stream);
|
bool CreateStream(DeviceStream *stream);
|
||||||
bool SyncStream(const DeviceStream& stream) const;
|
bool SyncStream(const DeviceStream &stream) const;
|
||||||
const DeviceStream& default_stream() const;
|
const DeviceStream &default_stream() const;
|
||||||
|
|
||||||
const cudnnHandle_t& GetCudnnHandle() const;
|
const cudnnHandle_t &GetCudnnHandle() const;
|
||||||
const cublasHandle_t& GetCublasHandle() const;
|
const cublasHandle_t &GetCublasHandle() const;
|
||||||
|
|
||||||
bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const;
|
bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
|
||||||
bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const;
|
bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
|
||||||
|
|
||||||
static GPUDeviceManager& GetInstance() {
|
static GPUDeviceManager &GetInstance() {
|
||||||
static GPUDeviceManager instance;
|
static GPUDeviceManager instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
@ -55,8 +55,8 @@ class GPUDeviceManager {
|
||||||
private:
|
private:
|
||||||
GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {}
|
GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {}
|
||||||
~GPUDeviceManager() = default;
|
~GPUDeviceManager() = default;
|
||||||
GPUDeviceManager(const GPUDeviceManager&) = delete;
|
GPUDeviceManager(const GPUDeviceManager &) = delete;
|
||||||
GPUDeviceManager& operator=(const GPUDeviceManager&) = delete;
|
GPUDeviceManager &operator=(const GPUDeviceManager &) = delete;
|
||||||
|
|
||||||
// default CUDA stream used for all the kernels.
|
// default CUDA stream used for all the kernels.
|
||||||
DeviceStream default_stream_{nullptr};
|
DeviceStream default_stream_{nullptr};
|
||||||
|
|
|
@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) {
|
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
|
||||||
auto alloc_size = AllocDeviceMem(size, addr);
|
auto alloc_size = AllocDeviceMem(size, addr);
|
||||||
buffer_q_addr_ = *addr;
|
buffer_q_addr_ = *addr;
|
||||||
// Buffer queue needs to ensure that the alloc_size and size is equal.
|
// Buffer queue needs to ensure that the alloc_size and size is equal.
|
||||||
return (alloc_size == size) ? true : false;
|
return (alloc_size == size) ? true : false;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
|
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
|
||||||
if (size == 0) {
|
if (size == 0) {
|
||||||
MS_LOG(EXCEPTION) << "The memory alloc size is 0.";
|
MS_LOG(EXCEPTION) << "The memory alloc size is 0.";
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
|
||||||
return alloc_size;
|
return alloc_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); }
|
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); }
|
||||||
|
|
||||||
size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); }
|
size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); }
|
||||||
|
|
||||||
|
|
|
@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit {
|
||||||
~GPUMemoryAllocator() override = default;
|
~GPUMemoryAllocator() override = default;
|
||||||
bool Init();
|
bool Init();
|
||||||
bool Finalize();
|
bool Finalize();
|
||||||
bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr);
|
bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr);
|
||||||
|
|
||||||
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
|
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
|
||||||
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
|
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
|
||||||
size_t free_mem_size() override;
|
size_t free_mem_size() override;
|
||||||
size_t total_mem_size() override;
|
size_t total_mem_size() override;
|
||||||
|
|
||||||
static GPUMemoryAllocator& GetInstance() {
|
static GPUMemoryAllocator &GetInstance() {
|
||||||
static GPUMemoryAllocator instance;
|
static GPUMemoryAllocator instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GPUMemoryAllocator() = default;
|
GPUMemoryAllocator() = default;
|
||||||
GPUMemoryAllocator(const GPUMemoryAllocator&) = delete;
|
GPUMemoryAllocator(const GPUMemoryAllocator &) = delete;
|
||||||
GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete;
|
GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete;
|
||||||
|
|
||||||
// Used to track address of data buffer queue.
|
// Used to track address of data buffer queue.
|
||||||
DeviceMemPtr buffer_q_addr_{nullptr};
|
DeviceMemPtr buffer_q_addr_{nullptr};
|
||||||
|
|
|
@ -33,8 +33,8 @@ namespace gpu {
|
||||||
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
||||||
using mindspore::kernel::KernelBuildInfo;
|
using mindspore::kernel::KernelBuildInfo;
|
||||||
namespace {
|
namespace {
|
||||||
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info,
|
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
|
||||||
const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
|
const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
|
||||||
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
||||||
MS_EXCEPTION_IF_NULL(alternative_kernel_info);
|
MS_EXCEPTION_IF_NULL(alternative_kernel_info);
|
||||||
size_t selected_input_num = selected_kernel_info->GetInputNum();
|
size_t selected_input_num = selected_kernel_info->GetInputNum();
|
||||||
|
@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string SupportedTypeList(const CNodePtr& kernel_node) {
|
std::string SupportedTypeList(const CNodePtr &kernel_node) {
|
||||||
std::string supported_type_lists =
|
std::string supported_type_lists =
|
||||||
kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
|
kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
|
||||||
if (!supported_type_lists.empty()) {
|
if (!supported_type_lists.empty()) {
|
||||||
|
@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) {
|
||||||
return supported_type_lists;
|
return supported_type_lists;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
|
bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
||||||
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
|
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
|
||||||
|
@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
||||||
[&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) {
|
[&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
|
||||||
return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
|
return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
|
||||||
});
|
});
|
||||||
if (!match) {
|
if (!match) {
|
||||||
|
@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) {
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||||
auto input_kernel_node = kernel_node->input(input_index + 1);
|
auto input_kernel_node = kernel_node->input(input_index + 1);
|
||||||
|
@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr& kernel_node) {
|
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
std::vector<std::string> inputs_format;
|
std::vector<std::string> inputs_format;
|
||||||
std::vector<TypeId> inputs_type;
|
std::vector<TypeId> inputs_type;
|
||||||
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
void SetKernelInfo(const CNodePtr& apply_kernel_ptr);
|
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
|
||||||
|
|
||||||
class KernelAttr {
|
class KernelAttr {
|
||||||
public:
|
public:
|
||||||
|
@ -35,24 +35,24 @@ class KernelAttr {
|
||||||
KernelAttr() : all_same_(false) {}
|
KernelAttr() : all_same_(false) {}
|
||||||
~KernelAttr() = default;
|
~KernelAttr() = default;
|
||||||
|
|
||||||
KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) {
|
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
|
||||||
input_type_.emplace_back(ms_type, format);
|
input_type_.emplace_back(ms_type, format);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) {
|
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
|
||||||
output_type_.emplace_back(ms_type, format);
|
output_type_.emplace_back(ms_type, format);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelAttr& AddAllSameAttr(const bool& all_same) {
|
KernelAttr &AddAllSameAttr(const bool &all_same) {
|
||||||
all_same_ = all_same;
|
all_same_ = all_same;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; }
|
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
||||||
const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
||||||
const bool& GetAllSame() const { return all_same_; }
|
const bool &GetAllSame() const { return all_same_; }
|
||||||
|
|
||||||
size_t GetInputSize() const { return input_type_.size(); }
|
size_t GetInputSize() const { return input_type_.size(); }
|
||||||
size_t GetOutputSize() const { return output_type_.size(); }
|
size_t GetOutputSize() const { return output_type_.size(); }
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
struct TypeIdManager* TypeIdManager::Get() {
|
struct TypeIdManager *TypeIdManager::Get() {
|
||||||
static TypeIdManager manager;
|
static TypeIdManager manager;
|
||||||
return &manager;
|
return &manager;
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra
|
||||||
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
|
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
|
||||||
|
|
||||||
std::string AnfNode::ToString() const {
|
std::string AnfNode::ToString() const {
|
||||||
return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info());
|
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph)
|
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
|
||||||
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
|
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
|
||||||
|
|
||||||
// Check if CNode is an apply with the specific Primitive.
|
// Check if CNode is an apply with the specific Primitive.
|
||||||
bool CNode::IsApply(const PrimitivePtr& value) const {
|
bool CNode::IsApply(const PrimitivePtr &value) const {
|
||||||
if (value == nullptr) {
|
if (value == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; }
|
void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; }
|
||||||
|
|
||||||
std::string CNode::DebugString(int recursive_level) const {
|
std::string CNode::DebugString(int recursive_level) const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
|
@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const {
|
||||||
buffer << ToString() << "{";
|
buffer << ToString() << "{";
|
||||||
bool is_first_node = true;
|
bool is_first_node = true;
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
for (auto& node : inputs_) {
|
for (auto &node : inputs_) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (is_first_node) {
|
if (is_first_node) {
|
||||||
is_first_node = false;
|
is_first_node = false;
|
||||||
|
@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) {
|
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
|
||||||
if (operator_info_ != nullptr) {
|
if (operator_info_ != nullptr) {
|
||||||
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
|
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
|
||||||
<< ", using the new one: " << operator_info->name();
|
<< ", using the new one: " << operator_info->name();
|
||||||
|
@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() {
|
||||||
return fullname_with_scope_;
|
return fullname_with_scope_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); }
|
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
|
||||||
void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); }
|
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
|
||||||
void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); }
|
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
|
||||||
|
|
||||||
bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
|
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
if (cnode != nullptr) {
|
if (cnode != nullptr) {
|
||||||
|
@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) {
|
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
|
bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
|
||||||
if (IsValueNode<Primitive>(node)) {
|
if (IsValueNode<Primitive>(node)) {
|
||||||
PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
|
PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
|
||||||
}
|
}
|
||||||
namespace id_generator {
|
namespace id_generator {
|
||||||
static std::unordered_map<std::string, int> node_ids;
|
static std::unordered_map<std::string, int> node_ids;
|
||||||
std::string get_id(const AnfNodePtr& node) {
|
std::string get_id(const AnfNodePtr &node) {
|
||||||
auto type_name = node->type_name();
|
auto type_name = node->type_name();
|
||||||
if (node_ids.find(type_name) == node_ids.end()) {
|
if (node_ids.find(type_name) == node_ids.end()) {
|
||||||
node_ids[type_name] = 0;
|
node_ids[type_name] = 0;
|
||||||
|
|
|
@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
|
||||||
class Base : public std::enable_shared_from_this<Base> {
|
class Base : public std::enable_shared_from_this<Base> {
|
||||||
public:
|
public:
|
||||||
constexpr Base() = default;
|
constexpr Base() = default;
|
||||||
Base(const Base& other) : std::enable_shared_from_this<Base>(other) {}
|
Base(const Base &other) : std::enable_shared_from_this<Base>(other) {}
|
||||||
virtual bool operator==(const Base& rhs) {
|
virtual bool operator==(const Base &rhs) {
|
||||||
if (this == &rhs) {
|
if (this == &rhs) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual Base& operator=(const Base&) { return *this; }
|
virtual Base &operator=(const Base &) { return *this; }
|
||||||
virtual ~Base() = default;
|
virtual ~Base() = default;
|
||||||
virtual std::size_t hash() const { return tid(); }
|
virtual std::size_t hash() const { return tid(); }
|
||||||
virtual std::string ToString() const { return type_name(); }
|
virtual std::string ToString() const { return type_name(); }
|
||||||
|
@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> {
|
||||||
|
|
||||||
virtual const bool IsFromTypeId(uint32_t tid) const;
|
virtual const bool IsFromTypeId(uint32_t tid) const;
|
||||||
virtual std::string type_name() const { return "Base"; }
|
virtual std::string type_name() const { return "Base"; }
|
||||||
static uint32_t GetTypeId(const char* const type_key);
|
static uint32_t GetTypeId(const char *const type_key);
|
||||||
virtual uint32_t tid() const {
|
virtual uint32_t tid() const {
|
||||||
static const uint32_t tid = GetTypeId(typeid(Base).name());
|
static const uint32_t tid = GetTypeId(typeid(Base).name());
|
||||||
return tid;
|
return tid;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,
|
template <typename T,
|
||||||
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr>
|
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr>
|
||||||
inline bool isa() const {
|
inline bool isa() const {
|
||||||
static const uint32_t tid = GetTypeId(typeid(T).name());
|
static const uint32_t tid = GetTypeId(typeid(T).name());
|
||||||
return this->IsFromTypeId(tid);
|
return this->IsFromTypeId(tid);
|
||||||
|
@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>;
|
||||||
using BaseWeakPtr = std::weak_ptr<Base>;
|
using BaseWeakPtr = std::weak_ptr<Base>;
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
inline T* cast(U* source) {
|
inline T *cast(U *source) {
|
||||||
if (source != nullptr && source->template isa<T>()) {
|
if (source != nullptr && source->template isa<T>()) {
|
||||||
return static_cast<T*>(source);
|
return static_cast<T *>(source);
|
||||||
} else {
|
} else {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ inline T* cast(U* source) {
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename T, typename U,
|
typename T, typename U,
|
||||||
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr>
|
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr>
|
||||||
inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
|
inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
|
||||||
if (r != nullptr && r->template isa<T>()) {
|
if (r != nullptr && r->template isa<T>()) {
|
||||||
return std::static_pointer_cast<T>(r);
|
return std::static_pointer_cast<T>(r);
|
||||||
|
@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager {
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
std::atomic<uint32_t> type_counter{0};
|
std::atomic<uint32_t> type_counter{0};
|
||||||
std::unordered_map<std::string, uint32_t> map;
|
std::unordered_map<std::string, uint32_t> map;
|
||||||
static TypeIdManager* Get();
|
static TypeIdManager *Get();
|
||||||
TypeIdManager() : mutex(), type_counter(0), map() {}
|
TypeIdManager() : mutex(), type_counter(0), map() {}
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -48,11 +48,11 @@ std::string Keyword::ToString() const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Keyword::operator==(const Type& other) const {
|
bool Keyword::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const auto& other_keyword = static_cast<const Keyword&>(other);
|
const auto &other_keyword = static_cast<const Keyword &>(other);
|
||||||
return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_);
|
return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,11 +87,11 @@ std::string Slice::ToString() const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Slice::operator==(const Type& other) const {
|
bool Slice::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto other_slice = static_cast<const Slice&>(other);
|
auto other_slice = static_cast<const Slice &>(other);
|
||||||
return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_);
|
return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,11 +122,11 @@ std::string TensorType::DumpText() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TensorType::operator==(const Type& other) const {
|
bool TensorType::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto other_elem_type = static_cast<const TensorType&>(other).element_type_;
|
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
|
||||||
// When element_type_ = nullptr, which means any type of Array.
|
// When element_type_ = nullptr, which means any type of Array.
|
||||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) {
|
||||||
retval_ = nullptr;
|
retval_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Function::Function(const std::vector<TypePtr>& args, const TypePtr retval)
|
Function::Function(const std::vector<TypePtr> &args, const TypePtr retval)
|
||||||
: Object(kObjectTypeFunction, false), args_(args), retval_(retval) {}
|
: Object(kObjectTypeFunction, false), args_(args), retval_(retval) {}
|
||||||
|
|
||||||
TypePtr Function::DeepCopy() const {
|
TypePtr Function::DeepCopy() const {
|
||||||
|
@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const {
|
||||||
TypePtrList args;
|
TypePtrList args;
|
||||||
TypePtr retval = nullptr;
|
TypePtr retval = nullptr;
|
||||||
(void)std::transform(args_.begin(), args_.end(), std::back_inserter(args),
|
(void)std::transform(args_.begin(), args_.end(), std::back_inserter(args),
|
||||||
[](const TypePtr& arg) { return arg->DeepCopy(); });
|
[](const TypePtr &arg) { return arg->DeepCopy(); });
|
||||||
if (retval_ != nullptr) {
|
if (retval_ != nullptr) {
|
||||||
retval = retval_->DeepCopy();
|
retval = retval_->DeepCopy();
|
||||||
}
|
}
|
||||||
|
@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Function::operator==(const Type& other) const {
|
bool Function::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& other_function = static_cast<const Function&>(other);
|
const auto &other_function = static_cast<const Function &>(other);
|
||||||
if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) {
|
if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) {
|
||||||
if (*retval_ != *other_function.retval_) {
|
if (*retval_ != *other_function.retval_) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -188,7 +188,7 @@ std::string Function::ToString() const {
|
||||||
} else {
|
} else {
|
||||||
buffer << "Func[(";
|
buffer << "Func[(";
|
||||||
bool begin = true;
|
bool begin = true;
|
||||||
for (auto& attr : args_) {
|
for (auto &attr : args_) {
|
||||||
if (!begin) {
|
if (!begin) {
|
||||||
buffer << ", ";
|
buffer << ", ";
|
||||||
} else {
|
} else {
|
||||||
|
@ -242,34 +242,34 @@ std::string JTagged::DumpText() const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) {
|
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
|
||||||
MS_EXCEPTION_IF_NULL(problem);
|
MS_EXCEPTION_IF_NULL(problem);
|
||||||
os << problem->ToString();
|
os << problem->ToString();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t TypeHasher::operator()(TypePtr const& type) const {
|
std::size_t TypeHasher::operator()(TypePtr const &type) const {
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
std::size_t hash = std::hash<size_t>()(type->type_id());
|
std::size_t hash = std::hash<size_t>()(type->type_id());
|
||||||
return hash;
|
return hash;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const {
|
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
|
||||||
std::size_t hash_sum = 0;
|
std::size_t hash_sum = 0;
|
||||||
for (auto& type : type_list) {
|
for (auto &type : type_list) {
|
||||||
auto type_id = static_cast<std::size_t>(type->type_id());
|
auto type_id = static_cast<std::size_t>(type->type_id());
|
||||||
hash_sum = hash_combine(hash_sum, type_id);
|
hash_sum = hash_combine(hash_sum, type_id);
|
||||||
}
|
}
|
||||||
return hash_sum;
|
return hash_sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const {
|
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
|
||||||
MS_EXCEPTION_IF_NULL(t1);
|
MS_EXCEPTION_IF_NULL(t1);
|
||||||
MS_EXCEPTION_IF_NULL(t2);
|
MS_EXCEPTION_IF_NULL(t2);
|
||||||
return t1->type_id() == t2->type_id();
|
return t1->type_id() == t2->type_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const {
|
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
|
||||||
if (lhs.size() != rhs.size()) {
|
if (lhs.size() != rhs.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) {
|
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
if (type_name == num_type_name) {
|
if (type_name == num_type_name) {
|
||||||
type = std::make_shared<T>();
|
type = std::make_shared<T>();
|
||||||
|
@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_
|
||||||
}
|
}
|
||||||
auto bits = std::stoi(type_name.substr(num_type_name.size()));
|
auto bits = std::stoi(type_name.substr(num_type_name.size()));
|
||||||
type = std::make_shared<T>(bits);
|
type = std::make_shared<T>(bits);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what();
|
MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
|
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
|
||||||
std::vector<TypePtr> types;
|
std::vector<TypePtr> types;
|
||||||
if (type_names.length() == 0) {
|
if (type_names.length() == 0) {
|
||||||
return types;
|
return types;
|
||||||
|
@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
|
||||||
return types;
|
return types;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr TensorStrToType(const std::string& type_name) {
|
TypePtr TensorStrToType(const std::string &type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
if (type_name == "Tensor") {
|
if (type_name == "Tensor") {
|
||||||
type = std::make_shared<TensorType>();
|
type = std::make_shared<TensorType>();
|
||||||
|
@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
type = std::make_shared<TensorType>(element_type);
|
type = std::make_shared<TensorType>(element_type);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr ListStrToType(const std::string& type_name) {
|
TypePtr ListStrToType(const std::string &type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
if (type_name == "List") {
|
if (type_name == "List") {
|
||||||
type = std::make_shared<List>();
|
type = std::make_shared<List>();
|
||||||
|
@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) {
|
||||||
std::string element_strs = type_name.substr(start, end - start);
|
std::string element_strs = type_name.substr(start, end - start);
|
||||||
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
||||||
bool wrong =
|
bool wrong =
|
||||||
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; });
|
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||||
if (wrong) {
|
if (wrong) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
type = std::make_shared<List>(element_types);
|
type = std::make_shared<List>(element_types);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) {
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr TupleStrToType(const std::string& type_name) {
|
TypePtr TupleStrToType(const std::string &type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
if (type_name == "Tuple") {
|
if (type_name == "Tuple") {
|
||||||
type = std::make_shared<Tuple>();
|
type = std::make_shared<Tuple>();
|
||||||
|
@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) {
|
||||||
std::string element_strs = type_name.substr(start, end - start);
|
std::string element_strs = type_name.substr(start, end - start);
|
||||||
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
||||||
bool wrong =
|
bool wrong =
|
||||||
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; });
|
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||||
if (wrong) {
|
if (wrong) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
type = std::make_shared<Tuple>(element_types);
|
type = std::make_shared<Tuple>(element_types);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr FunctionStrToType(const std::string& type_name) {
|
TypePtr FunctionStrToType(const std::string &type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
|
|
||||||
if (type_name == "Function") {
|
if (type_name == "Function") {
|
||||||
|
@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) {
|
||||||
|
|
||||||
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
|
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
|
||||||
TypePtr retval = StringToType(str_retval);
|
TypePtr retval = StringToType(str_retval);
|
||||||
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; });
|
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||||
if (retval == nullptr || wrong) {
|
if (retval == nullptr || wrong) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
type = std::make_shared<Function>(args_type, retval);
|
type = std::make_shared<Function>(args_type, retval);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) {
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TypePtr StringToType(const std::string& type_name) {
|
TypePtr StringToType(const std::string &type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
if (type_name.compare("None") == 0) {
|
if (type_name.compare("None") == 0) {
|
||||||
type = std::make_shared<TypeNone>();
|
type = std::make_shared<TypeNone>();
|
||||||
|
@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) {
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
|
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
|
||||||
if (x == nullptr || base_type == nullptr) {
|
if (x == nullptr || base_type == nullptr) {
|
||||||
MS_LOG(ERROR) << "Type is nullptr.";
|
MS_LOG(ERROR) << "Type is nullptr.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
|
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
|
||||||
MS_EXCEPTION_IF_NULL(t1);
|
MS_EXCEPTION_IF_NULL(t1);
|
||||||
if (t1->type_id() == kTypeUnknown) {
|
if (t1->type_id() == kTypeUnknown) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(
|
REGISTER_PYBIND_DEFINE(
|
||||||
typing, ([](py::module* const m) {
|
typing, ([](py::module *const m) {
|
||||||
auto m_sub = m->def_submodule("typing", "submodule for dtype");
|
auto m_sub = m->def_submodule("typing", "submodule for dtype");
|
||||||
py::enum_<TypeId>(m_sub, "TypeId");
|
py::enum_<TypeId>(m_sub, "TypeId");
|
||||||
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
|
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
|
||||||
(void)m_sub.def("load_type", &TypeIdToType, "load type");
|
(void)m_sub.def("load_type", &TypeIdToType, "load type");
|
||||||
(void)m_sub.def(
|
(void)m_sub.def(
|
||||||
"dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type");
|
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
|
||||||
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
||||||
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
||||||
.def("__eq__",
|
.def("__eq__",
|
||||||
[](const TypePtr& t1, const TypePtr& t2) {
|
[](const TypePtr &t1, const TypePtr &t2) {
|
||||||
if (t1 != nullptr && t2 != nullptr) {
|
if (t1 != nullptr && t2 != nullptr) {
|
||||||
return *t1 == *t2;
|
return *t1 == *t2;
|
||||||
}
|
}
|
||||||
|
@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE(
|
||||||
.def("__hash__", &Type::hash)
|
.def("__hash__", &Type::hash)
|
||||||
.def("__str__", &Type::ToString)
|
.def("__str__", &Type::ToString)
|
||||||
.def("__repr__", &Type::ReprString)
|
.def("__repr__", &Type::ReprString)
|
||||||
.def("__deepcopy__", [](const TypePtr& t, py::dict) {
|
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
|
||||||
if (t == nullptr) {
|
if (t == nullptr) {
|
||||||
return static_cast<TypePtr>(nullptr);
|
return static_cast<TypePtr>(nullptr);
|
||||||
}
|
}
|
||||||
|
@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE(
|
||||||
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
|
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const Bool&) { // __getstate__
|
[](const Bool &) { // __getstate__
|
||||||
return py::make_tuple();
|
return py::make_tuple();
|
||||||
},
|
},
|
||||||
[](const py::tuple&) { // __setstate__
|
[](const py::tuple &) { // __setstate__
|
||||||
return std::make_shared<Bool>();
|
return std::make_shared<Bool>();
|
||||||
}));
|
}));
|
||||||
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
|
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def(py::init<int>(), py::arg("nbits"))
|
.def(py::init<int>(), py::arg("nbits"))
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const Int& t) { // __getstate__
|
[](const Int &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
return py::make_tuple(py::int_(t.nbits()));
|
return py::make_tuple(py::int_(t.nbits()));
|
||||||
},
|
},
|
||||||
[](const py::tuple& t) { // __setstate__
|
[](const py::tuple &t) { // __setstate__
|
||||||
if (t.size() != 1) {
|
if (t.size() != 1) {
|
||||||
throw std::runtime_error("Invalid state!");
|
throw std::runtime_error("Invalid state!");
|
||||||
}
|
}
|
||||||
|
@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE(
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def(py::init<int>(), py::arg("nbits"))
|
.def(py::init<int>(), py::arg("nbits"))
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const UInt& t) { // __getstate__
|
[](const UInt &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
return py::make_tuple(py::int_(t.nbits()));
|
return py::make_tuple(py::int_(t.nbits()));
|
||||||
},
|
},
|
||||||
[](const py::tuple& t) { // __setstate__
|
[](const py::tuple &t) { // __setstate__
|
||||||
if (t.size() != 1) {
|
if (t.size() != 1) {
|
||||||
throw std::runtime_error("Invalid state!");
|
throw std::runtime_error("Invalid state!");
|
||||||
}
|
}
|
||||||
|
@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE(
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def(py::init<int>(), py::arg("nbits"))
|
.def(py::init<int>(), py::arg("nbits"))
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const Float& t) { // __getstate__
|
[](const Float &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
return py::make_tuple(py::int_(t.nbits()));
|
return py::make_tuple(py::int_(t.nbits()));
|
||||||
},
|
},
|
||||||
[](const py::tuple& t) { // __setstate__
|
[](const py::tuple &t) { // __setstate__
|
||||||
if (t.size() != 1) {
|
if (t.size() != 1) {
|
||||||
throw std::runtime_error("Invalid state!");
|
throw std::runtime_error("Invalid state!");
|
||||||
}
|
}
|
||||||
|
@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE(
|
||||||
.def(py::init<TypePtr>(), py::arg("element"))
|
.def(py::init<TypePtr>(), py::arg("element"))
|
||||||
.def("element_type", &TensorType::element)
|
.def("element_type", &TensorType::element)
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const TensorType& t) { // __getstate__
|
[](const TensorType &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
|
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
|
||||||
},
|
},
|
||||||
[](const py::tuple& t) { // __setstate__
|
[](const py::tuple &t) { // __setstate__
|
||||||
if (t.size() != 1) {
|
if (t.size() != 1) {
|
||||||
throw std::runtime_error("Invalid state!");
|
throw std::runtime_error("Invalid state!");
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>;
|
||||||
class Keyword : public Object {
|
class Keyword : public Object {
|
||||||
public:
|
public:
|
||||||
Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
|
Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
|
||||||
Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
|
Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
|
||||||
|
|
||||||
~Keyword() override = default;
|
~Keyword() override = default;
|
||||||
MS_DECLARE_PARENT(Keyword, Object)
|
MS_DECLARE_PARENT(Keyword, Object)
|
||||||
|
@ -70,7 +70,7 @@ class Keyword : public Object {
|
||||||
|
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
std::string GetKey() const { return key_; }
|
std::string GetKey() const { return key_; }
|
||||||
TypePtr GetValue() const { return value_; }
|
TypePtr GetValue() const { return value_; }
|
||||||
|
@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>;
|
||||||
class Slice : public Object {
|
class Slice : public Object {
|
||||||
public:
|
public:
|
||||||
Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
|
Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
|
||||||
Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step)
|
Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
|
||||||
: Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
|
: Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
|
||||||
|
|
||||||
~Slice() override = default;
|
~Slice() override = default;
|
||||||
|
@ -95,7 +95,7 @@ class Slice : public Object {
|
||||||
|
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
TypePtr get_start() const { return start_; }
|
TypePtr get_start() const { return start_; }
|
||||||
TypePtr get_stop() const { return stop_; }
|
TypePtr get_stop() const { return stop_; }
|
||||||
|
@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>;
|
||||||
class TensorType : public Object {
|
class TensorType : public Object {
|
||||||
public:
|
public:
|
||||||
TensorType() : Object(kObjectTypeTensorType) {}
|
TensorType() : Object(kObjectTypeTensorType) {}
|
||||||
explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
|
explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
|
||||||
~TensorType() override = default;
|
~TensorType() override = default;
|
||||||
MS_DECLARE_PARENT(TensorType, Object)
|
MS_DECLARE_PARENT(TensorType, Object)
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
|
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
|
||||||
const TypePtr element() const { return element_type_; }
|
const TypePtr element() const { return element_type_; }
|
||||||
void set_element(const TypePtr& element_type) { element_type_ = element_type; }
|
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||||
|
|
||||||
TypePtr DeepCopy() const override;
|
TypePtr DeepCopy() const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string ToReprString() const override { return "tensor"; }
|
std::string ToReprString() const override { return "tensor"; }
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TypePtr element_type_;
|
TypePtr element_type_;
|
||||||
|
@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
|
||||||
class Function : public Object {
|
class Function : public Object {
|
||||||
public:
|
public:
|
||||||
Function();
|
Function();
|
||||||
Function(const std::vector<TypePtr>& args, const TypePtr retval);
|
Function(const std::vector<TypePtr> &args, const TypePtr retval);
|
||||||
~Function() override = default;
|
~Function() override = default;
|
||||||
MS_DECLARE_PARENT(Function, Object)
|
MS_DECLARE_PARENT(Function, Object)
|
||||||
|
|
||||||
|
@ -141,11 +141,11 @@ class Function : public Object {
|
||||||
|
|
||||||
// Add temporarily for return abstraction to avoid type checking.
|
// Add temporarily for return abstraction to avoid type checking.
|
||||||
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
|
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
|
||||||
const std::vector<TypePtr>& args() const { return args_; }
|
const std::vector<TypePtr> &args() const { return args_; }
|
||||||
const TypePtr& retval() const { return retval_; }
|
const TypePtr &retval() const { return retval_; }
|
||||||
|
|
||||||
TypePtr DeepCopy() const override;
|
TypePtr DeepCopy() const override;
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string ToReprString() const override { return "function"; }
|
std::string ToReprString() const override { return "function"; }
|
||||||
|
|
||||||
|
@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>;
|
||||||
class JTagged : public Object {
|
class JTagged : public Object {
|
||||||
public:
|
public:
|
||||||
JTagged() : Object(kObjectTypeJTagged) {}
|
JTagged() : Object(kObjectTypeJTagged) {}
|
||||||
explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
|
explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
|
||||||
~JTagged() override = default;
|
~JTagged() override = default;
|
||||||
MS_DECLARE_PARENT(JTagged, Object)
|
MS_DECLARE_PARENT(JTagged, Object)
|
||||||
|
|
||||||
|
@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>;
|
||||||
class Problem : public Type {
|
class Problem : public Type {
|
||||||
public:
|
public:
|
||||||
Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
|
Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
|
||||||
explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {}
|
explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
|
||||||
~Problem() override = default;
|
~Problem() override = default;
|
||||||
MS_DECLARE_PARENT(Problem, Type)
|
MS_DECLARE_PARENT(Problem, Type)
|
||||||
|
|
||||||
|
@ -222,7 +222,7 @@ class Problem : public Type {
|
||||||
std::string ToString() const override { return kind_.name(); }
|
std::string ToString() const override { return kind_.name(); }
|
||||||
std::string DumpText() const override { return "ProblemType"; }
|
std::string DumpText() const override { return "ProblemType"; }
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem);
|
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Named kind_;
|
Named kind_;
|
||||||
|
@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>;
|
||||||
|
|
||||||
// helper template
|
// helper template
|
||||||
template <class T>
|
template <class T>
|
||||||
TypePtr Clone(const T& t) {
|
TypePtr Clone(const T &t) {
|
||||||
return t.Clone();
|
return t.Clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr StringToType(const std::string& type_name);
|
TypePtr StringToType(const std::string &type_name);
|
||||||
|
|
||||||
// Judge whether x is predicate or is a subclass of predicate.
|
// Judge whether x is predicate or is a subclass of predicate.
|
||||||
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type);
|
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
|
||||||
|
|
||||||
// Whether t1 is identity or a subclass of t2.
|
// Whether t1 is identity or a subclass of t2.
|
||||||
bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr);
|
bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
|
||||||
|
|
||||||
struct TypeHasher {
|
struct TypeHasher {
|
||||||
std::size_t operator()(TypePtr const& type) const;
|
std::size_t operator()(TypePtr const &type) const;
|
||||||
};
|
};
|
||||||
struct TypeListHasher {
|
struct TypeListHasher {
|
||||||
std::size_t operator()(const TypePtrList& type_list) const;
|
std::size_t operator()(const TypePtrList &type_list) const;
|
||||||
};
|
};
|
||||||
struct TypeEqual {
|
struct TypeEqual {
|
||||||
bool operator()(TypePtr const& t1, TypePtr const& t2) const;
|
bool operator()(TypePtr const &t1, TypePtr const &t2) const;
|
||||||
};
|
};
|
||||||
struct TypeListEqual {
|
struct TypeListEqual {
|
||||||
bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const;
|
bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern const TypePtr kTypeExternal;
|
extern const TypePtr kTypeExternal;
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include "pybind_api/export_flags.h"
|
#include "pybind_api/export_flags.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) {
|
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
bool begin = true;
|
bool begin = true;
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const {
|
||||||
} else {
|
} else {
|
||||||
TypePtrList elements;
|
TypePtrList elements;
|
||||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
|
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
|
||||||
[](const TypePtr& ele) { return ele->DeepCopy(); });
|
[](const TypePtr &ele) { return ele->DeepCopy(); });
|
||||||
auto copy = std::make_shared<List>(elements);
|
auto copy = std::make_shared<List>(elements);
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
|
@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const {
|
||||||
return elements_[dim];
|
return elements_[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
bool List::operator==(const Type& other) const {
|
bool List::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const List& other_list = static_cast<const List&>(other);
|
const List &other_list = static_cast<const List &>(other);
|
||||||
if (elements_.size() != other_list.elements_.size()) {
|
if (elements_.size() != other_list.elements_.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Class::Class(const Named& tag, const ClassAttrVector& attributes,
|
Class::Class(const Named &tag, const ClassAttrVector &attributes,
|
||||||
const std::unordered_map<std::string, ValuePtr>& methods)
|
const std::unordered_map<std::string, ValuePtr> &methods)
|
||||||
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
|
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
|
||||||
|
|
||||||
std::string List::ToString() const {
|
std::string List::ToString() const {
|
||||||
|
@ -122,7 +122,7 @@ std::string List::DumpText() const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Class::operator==(const Type& other) const {
|
bool Class::operator==(const Type &other) const {
|
||||||
// Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj.
|
// Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj.
|
||||||
return &other == this;
|
return &other == this;
|
||||||
}
|
}
|
||||||
|
@ -143,7 +143,7 @@ std::string Class::ToString() const {
|
||||||
} else {
|
} else {
|
||||||
bool begin = true;
|
bool begin = true;
|
||||||
buffer << "cls." << tag_ << "[";
|
buffer << "cls." << tag_ << "[";
|
||||||
for (auto& attr : attributes_) {
|
for (auto &attr : attributes_) {
|
||||||
if (!begin) {
|
if (!begin) {
|
||||||
buffer << ", ";
|
buffer << ", ";
|
||||||
} else {
|
} else {
|
||||||
|
@ -163,7 +163,7 @@ std::string Class::DumpText() const {
|
||||||
} else {
|
} else {
|
||||||
bool begin = true;
|
bool begin = true;
|
||||||
buffer << "Cls." << tag_ << "[";
|
buffer << "Cls." << tag_ << "[";
|
||||||
for (auto& attr : attributes_) {
|
for (auto &attr : attributes_) {
|
||||||
if (!begin) {
|
if (!begin) {
|
||||||
buffer << ", ";
|
buffer << ", ";
|
||||||
} else {
|
} else {
|
||||||
|
@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const {
|
||||||
} else {
|
} else {
|
||||||
TypePtrList elements;
|
TypePtrList elements;
|
||||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
|
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
|
||||||
[](const TypePtr& ele) { return ele->DeepCopy(); });
|
[](const TypePtr &ele) { return ele->DeepCopy(); });
|
||||||
auto copy = std::make_shared<Tuple>(elements);
|
auto copy = std::make_shared<Tuple>(elements);
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tuple::operator==(const Type& other) const {
|
bool Tuple::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto other_tuple = static_cast<const Tuple&>(other);
|
auto other_tuple = static_cast<const Tuple &>(other);
|
||||||
if (elements_.size() != other_tuple.elements_.size()) {
|
if (elements_.size() != other_tuple.elements_.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const {
|
||||||
std::vector<std::pair<std::string, TypePtr>> kv;
|
std::vector<std::pair<std::string, TypePtr>> kv;
|
||||||
(void)std::transform(
|
(void)std::transform(
|
||||||
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
||||||
[](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); });
|
[](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); });
|
||||||
return std::make_shared<Dictionary>(kv);
|
return std::make_shared<Dictionary>(kv);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -259,7 +259,7 @@ std::string Dictionary::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
std::vector<std::string> keys;
|
std::vector<std::string> keys;
|
||||||
std::vector<TypePtr> values;
|
std::vector<TypePtr> values;
|
||||||
for (const auto& kv : key_values_) {
|
for (const auto &kv : key_values_) {
|
||||||
keys.push_back(kv.first);
|
keys.push_back(kv.first);
|
||||||
values.push_back(kv.second);
|
values.push_back(kv.second);
|
||||||
}
|
}
|
||||||
|
@ -276,12 +276,12 @@ std::string Dictionary::ToString() const {
|
||||||
|
|
||||||
std::string Dictionary::DumpText() const { return ToString(); }
|
std::string Dictionary::DumpText() const { return ToString(); }
|
||||||
|
|
||||||
bool Dictionary::operator==(const mindspore::Type& other) const {
|
bool Dictionary::operator==(const mindspore::Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& other_dict = static_cast<const Dictionary&>(other);
|
const auto &other_dict = static_cast<const Dictionary &>(other);
|
||||||
if (key_values_.size() != other_dict.key_values_.size()) {
|
if (key_values_.size() != other_dict.key_values_.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,10 +40,10 @@ namespace mindspore {
|
||||||
class List : public Object {
|
class List : public Object {
|
||||||
public:
|
public:
|
||||||
List() : Object(kObjectTypeList) {}
|
List() : Object(kObjectTypeList) {}
|
||||||
List(const std::initializer_list<TypePtr>& objs)
|
List(const std::initializer_list<TypePtr> &objs)
|
||||||
: Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {}
|
: Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {}
|
||||||
// Shadow copy;
|
// Shadow copy;
|
||||||
explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {}
|
explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {}
|
||||||
~List() override {}
|
~List() override {}
|
||||||
MS_DECLARE_PARENT(List, Object)
|
MS_DECLARE_PARENT(List, Object)
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ class List : public Object {
|
||||||
TypeId generic_type_id() const override { return kObjectTypeList; }
|
TypeId generic_type_id() const override { return kObjectTypeList; }
|
||||||
TypePtr DeepCopy() const override;
|
TypePtr DeepCopy() const override;
|
||||||
|
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
std::size_t size() const { return elements_.size(); }
|
std::size_t size() const { return elements_.size(); }
|
||||||
TypePtrList elements() const { return elements_; }
|
TypePtrList elements() const { return elements_; }
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
|
@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>;
|
||||||
class Class : public Object {
|
class Class : public Object {
|
||||||
public:
|
public:
|
||||||
Class() : Object(kObjectTypeClass), tag_(Named("Class")) {}
|
Class() : Object(kObjectTypeClass), tag_(Named("Class")) {}
|
||||||
Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods);
|
Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods);
|
||||||
~Class() override {}
|
~Class() override {}
|
||||||
MS_DECLARE_PARENT(Class, Object)
|
MS_DECLARE_PARENT(Class, Object)
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeClass; }
|
TypeId generic_type_id() const override { return kObjectTypeClass; }
|
||||||
|
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
TypePtr DeepCopy() const override;
|
TypePtr DeepCopy() const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; }
|
void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
|
||||||
|
|
||||||
Named tag() { return tag_; }
|
Named tag() { return tag_; }
|
||||||
std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; }
|
std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; }
|
||||||
std::unordered_map<std::string, ValuePtr> methods() { return methods_; }
|
std::unordered_map<std::string, ValuePtr> methods() { return methods_; }
|
||||||
ClassAttrVector& GetAttributes() { return attributes_; }
|
ClassAttrVector &GetAttributes() { return attributes_; }
|
||||||
|
|
||||||
ClassAttrVector attributes_;
|
ClassAttrVector attributes_;
|
||||||
|
|
||||||
|
@ -99,11 +99,11 @@ class Tuple : public Object {
|
||||||
public:
|
public:
|
||||||
Tuple() : Object(kObjectTypeTuple) {}
|
Tuple() : Object(kObjectTypeTuple) {}
|
||||||
// usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)};
|
// usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)};
|
||||||
Tuple(const std::initializer_list<TypePtr>& objs)
|
Tuple(const std::initializer_list<TypePtr> &objs)
|
||||||
: Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
|
: Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
|
||||||
|
|
||||||
// Shadow copy
|
// Shadow copy
|
||||||
explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
|
explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
|
||||||
|
|
||||||
~Tuple() override {}
|
~Tuple() override {}
|
||||||
MS_DECLARE_PARENT(Tuple, Object)
|
MS_DECLARE_PARENT(Tuple, Object)
|
||||||
|
@ -115,7 +115,7 @@ class Tuple : public Object {
|
||||||
std::string ToReprString() const override { return "tuple_"; }
|
std::string ToReprString() const override { return "tuple_"; }
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
const TypePtr operator[](size_t dim) const;
|
const TypePtr operator[](size_t dim) const;
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
TypePtrList elements() const { return elements_; }
|
TypePtrList elements() const { return elements_; }
|
||||||
std::size_t size() const { return elements_.size(); }
|
std::size_t size() const { return elements_.size(); }
|
||||||
|
@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>;
|
||||||
class Dictionary : public Object {
|
class Dictionary : public Object {
|
||||||
public:
|
public:
|
||||||
Dictionary() : Object(kObjectTypeDictionary) {}
|
Dictionary() : Object(kObjectTypeDictionary) {}
|
||||||
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values)
|
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values)
|
||||||
: Object(kObjectTypeDictionary, false), key_values_(key_values) {}
|
: Object(kObjectTypeDictionary, false), key_values_(key_values) {}
|
||||||
|
|
||||||
~Dictionary() override {}
|
~Dictionary() override {}
|
||||||
|
@ -136,7 +136,7 @@ class Dictionary : public Object {
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeDictionary; }
|
TypeId generic_type_id() const override { return kObjectTypeDictionary; }
|
||||||
|
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
TypePtr DeepCopy() const override;
|
TypePtr DeepCopy() const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
|
|
|
@ -24,11 +24,11 @@
|
||||||
#include "pybind_api/export_flags.h"
|
#include "pybind_api/export_flags.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
bool Number::operator==(const Type& other) const {
|
bool Number::operator==(const Type &other) const {
|
||||||
if (!IsSameObjectType(*this, other)) {
|
if (!IsSameObjectType(*this, other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto other_number = static_cast<const Number&>(other);
|
auto other_number = static_cast<const Number &>(other);
|
||||||
return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_));
|
return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,12 +49,12 @@ class Number : public Object {
|
||||||
TypeId type_id() const override { return number_type_; }
|
TypeId type_id() const override { return number_type_; }
|
||||||
TypeId generic_type_id() const override { return kObjectTypeNumber; }
|
TypeId generic_type_id() const override { return kObjectTypeNumber; }
|
||||||
|
|
||||||
bool operator==(const Type& other) const override;
|
bool operator==(const Type &other) const override;
|
||||||
TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
|
TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
|
||||||
std::string ToString() const override { return "Number"; }
|
std::string ToString() const override { return "Number"; }
|
||||||
std::string ToReprString() const override { return "number"; }
|
std::string ToReprString() const override { return "number"; }
|
||||||
std::string DumpText() const override { return "Number"; }
|
std::string DumpText() const override { return "Number"; }
|
||||||
std::string GetTypeName(const std::string& type_name) const {
|
std::string GetTypeName(const std::string &type_name) const {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << type_name;
|
oss << type_name;
|
||||||
if (nbits() != 0) {
|
if (nbits() != 0) {
|
||||||
|
|
|
@ -51,7 +51,7 @@ class RefKeyType : public Object {
|
||||||
class RefType : public Object {
|
class RefType : public Object {
|
||||||
public:
|
public:
|
||||||
RefType() : Object(kObjectTypeRef) {}
|
RefType() : Object(kObjectTypeRef) {}
|
||||||
RefType(const TypePtr& subtype, const TypePtr& subtype_origin)
|
RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
|
||||||
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
|
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
|
||||||
~RefType() override {}
|
~RefType() override {}
|
||||||
MS_DECLARE_PARENT(RefType, Object)
|
MS_DECLARE_PARENT(RefType, Object)
|
||||||
|
|
|
@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* MetaIdLabel(const TypeId& v) {
|
const char *MetaIdLabel(const TypeId &v) {
|
||||||
switch (v) {
|
switch (v) {
|
||||||
case kTypeUnknown:
|
case kTypeUnknown:
|
||||||
return "kTypeUnknown";
|
return "kTypeUnknown";
|
||||||
|
@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* ObjectIdLabel(const TypeId& v) {
|
const char *ObjectIdLabel(const TypeId &v) {
|
||||||
switch (v) {
|
switch (v) {
|
||||||
case kObjectTypeNumber:
|
case kObjectTypeNumber:
|
||||||
return "kObjectTypeNumber";
|
return "kObjectTypeNumber";
|
||||||
|
@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* NumberIdLabel(const TypeId& v) {
|
const char *NumberIdLabel(const TypeId &v) {
|
||||||
switch (v) {
|
switch (v) {
|
||||||
case kNumberTypeBool:
|
case kNumberTypeBool:
|
||||||
return "kNumberTypeBool";
|
return "kNumberTypeBool";
|
||||||
|
@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* TypeIdLabel(const TypeId& v) {
|
const char *TypeIdLabel(const TypeId &v) {
|
||||||
if (v < kMetaTypeEnd) {
|
if (v < kMetaTypeEnd) {
|
||||||
return MetaIdLabel(v);
|
return MetaIdLabel(v);
|
||||||
} else {
|
} else {
|
||||||
|
@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsSameObjectType(const Type& lhs, const Type& rhs) {
|
bool IsSameObjectType(const Type &lhs, const Type &rhs) {
|
||||||
if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
|
if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return lhs.object_type() == rhs.object_type();
|
return lhs.object_type() == rhs.object_type();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GetTypeByte(const TypePtr& type_ptr) {
|
size_t GetTypeByte(const TypePtr &type_ptr) {
|
||||||
if (type_ptr && type_ptr->isa<Number>()) {
|
if (type_ptr && type_ptr->isa<Number>()) {
|
||||||
auto number = dyn_cast<Number>(type_ptr);
|
auto number = dyn_cast<Number>(type_ptr);
|
||||||
if (!number) {
|
if (!number) {
|
||||||
|
@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Type::operator==(const Value& other) const {
|
bool Type::operator==(const Value &other) const {
|
||||||
if (other.isa<Type>()) {
|
if (other.isa<Type>()) {
|
||||||
auto other_type = static_cast<const Type*>(&other);
|
auto other_type = static_cast<const Type *>(&other);
|
||||||
return *this == *other_type;
|
return *this == *other_type;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() {
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Type& type) {
|
std::ostream &operator<<(std::ostream &os, const Type &type) {
|
||||||
os << type.ToString();
|
os << type.ToString();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const TypePtr type) {
|
std::ostream &operator<<(std::ostream &os, const TypePtr type) {
|
||||||
os << type->ToString();
|
os << type->ToString();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Object& obj) {
|
std::ostream &operator<<(std::ostream &os, const Object &obj) {
|
||||||
os << obj.ToString();
|
os << obj.ToString();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) {
|
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
|
||||||
os << obj->ToString();
|
os << obj->ToString();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const TypePtrList& types) {
|
std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
|
||||||
os << "[";
|
os << "[";
|
||||||
for (size_t i = 0; i < types.size(); ++i) {
|
for (size_t i = 0; i < types.size(); ++i) {
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
|
|
|
@ -95,10 +95,10 @@ enum TypeId : int {
|
||||||
TypeId IntBitsToTypeId(const int nbits);
|
TypeId IntBitsToTypeId(const int nbits);
|
||||||
TypeId UIntBitsToTypeId(const int nbits);
|
TypeId UIntBitsToTypeId(const int nbits);
|
||||||
TypeId FloatBitsToTypeId(const int nbits);
|
TypeId FloatBitsToTypeId(const int nbits);
|
||||||
const char* TypeIdLabel(const TypeId& v);
|
const char *TypeIdLabel(const TypeId &v);
|
||||||
TypeId NormalizeTypeId(const TypeId type_id);
|
TypeId NormalizeTypeId(const TypeId type_id);
|
||||||
bool IsSameObjectType(const Type& lhs, const Type& rhs);
|
bool IsSameObjectType(const Type &lhs, const Type &rhs);
|
||||||
size_t GetTypeByte(const TypePtr& type_ptr);
|
size_t GetTypeByte(const TypePtr &type_ptr);
|
||||||
|
|
||||||
// Base class for all types
|
// Base class for all types
|
||||||
// forward declaration.
|
// forward declaration.
|
||||||
|
@ -110,14 +110,14 @@ class Type : public Value {
|
||||||
~Type() override = default;
|
~Type() override = default;
|
||||||
MS_DECLARE_PARENT(Type, Value)
|
MS_DECLARE_PARENT(Type, Value)
|
||||||
|
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
TypeId meta_type() const { return meta_type_; }
|
TypeId meta_type() const { return meta_type_; }
|
||||||
|
|
||||||
virtual TypeId type_id() const { return meta_type_; }
|
virtual TypeId type_id() const { return meta_type_; }
|
||||||
virtual TypeId generic_type_id() const { return kMetaTypeType; }
|
virtual TypeId generic_type_id() const { return kMetaTypeType; }
|
||||||
|
|
||||||
virtual bool operator!=(const Type& other) const { return !(*this == other); }
|
virtual bool operator!=(const Type &other) const { return !(*this == other); }
|
||||||
virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); }
|
virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); }
|
||||||
virtual bool equal(const TypePtr other) const { return *this == *other; }
|
virtual bool equal(const TypePtr other) const { return *this == *other; }
|
||||||
|
|
||||||
virtual TypeId object_type() const { return kTypeUnknown; }
|
virtual TypeId object_type() const { return kTypeUnknown; }
|
||||||
|
@ -134,8 +134,8 @@ class Type : public Value {
|
||||||
bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
|
bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
|
||||||
bool IsGeneric() const { return is_generic_; }
|
bool IsGeneric() const { return is_generic_; }
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
friend std::ostream& operator<<(std::ostream& os, const Type& type);
|
friend std::ostream &operator<<(std::ostream &os, const Type &type);
|
||||||
friend std::ostream& operator<<(std::ostream& os, const TypePtr type);
|
friend std::ostream &operator<<(std::ostream &os, const TypePtr type);
|
||||||
|
|
||||||
const bool parse_info_ = true;
|
const bool parse_info_ = true;
|
||||||
|
|
||||||
|
@ -163,14 +163,14 @@ class Object : public Type {
|
||||||
bool equal(const TypePtr other) const override;
|
bool equal(const TypePtr other) const override;
|
||||||
std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
|
std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const Object& obj);
|
friend std::ostream &operator<<(std::ostream &os, const Object &obj);
|
||||||
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj);
|
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const TypeId object_type_;
|
const TypeId object_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const TypePtrList& types);
|
std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_
|
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_
|
||||||
|
|
|
@ -61,7 +61,7 @@ FuncGraph::FuncGraph()
|
||||||
AbstractFunctionPtr FuncGraph::abstract() {
|
AbstractFunctionPtr FuncGraph::abstract() {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
|
|
||||||
for (auto& p : parameters_) {
|
for (auto &p : parameters_) {
|
||||||
MS_EXCEPTION_IF_NULL(p);
|
MS_EXCEPTION_IF_NULL(p);
|
||||||
if (p->abstract() == nullptr) {
|
if (p->abstract() == nullptr) {
|
||||||
MS_LOG(ERROR) << "Error!!";
|
MS_LOG(ERROR) << "Error!!";
|
||||||
|
@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() {
|
||||||
return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
|
return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) {
|
abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) {
|
||||||
AnalysisContextPtr temp_context = context;
|
AnalysisContextPtr temp_context = context;
|
||||||
if (temp_context == nullptr) {
|
if (temp_context == nullptr) {
|
||||||
temp_context = abstract::AnalysisContext::DummyContext();
|
temp_context = abstract::AnalysisContext::DummyContext();
|
||||||
|
@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) {
|
void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
|
||||||
if (force_new_ret || return_ == nullptr) {
|
if (force_new_ret || return_ == nullptr) {
|
||||||
std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
|
std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
|
||||||
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
|
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
|
||||||
|
@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() {
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::add_parameter(const ParameterPtr& p) {
|
void FuncGraph::add_parameter(const ParameterPtr &p) {
|
||||||
if (manager_.lock()) {
|
if (manager_.lock()) {
|
||||||
std::vector<AnfNodePtr> new_params = parameters_;
|
std::vector<AnfNodePtr> new_params = parameters_;
|
||||||
new_params.push_back(p);
|
new_params.push_back(p);
|
||||||
|
@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) {
|
ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
|
||||||
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
|
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
|
||||||
ParameterPtr p = std::make_shared<Parameter>(this_graph);
|
ParameterPtr p = std::make_shared<Parameter>(this_graph);
|
||||||
p->set_name(name);
|
p->set_name(name);
|
||||||
|
@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) {
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraph::has_flag(const std::string& flag) {
|
bool FuncGraph::has_flag(const std::string &flag) {
|
||||||
if (flags_.count(flag)) {
|
if (flags_.count(flag)) {
|
||||||
return flags_[flag];
|
return flags_[flag];
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) {
|
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
||||||
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
|
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
|
||||||
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
||||||
order_.push_back(cnode);
|
order_.push_back(cnode);
|
||||||
|
@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) {
|
||||||
return cnode;
|
return cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, const ScopePtr& scope) {
|
CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) {
|
||||||
CNodePtr app = NewCNode(inputs);
|
CNodePtr app = NewCNode(inputs);
|
||||||
app->set_scope(scope);
|
app->set_scope(scope);
|
||||||
return app;
|
return app;
|
||||||
|
@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, con
|
||||||
|
|
||||||
void FuncGraph::DumpCNodeList() {
|
void FuncGraph::DumpCNodeList() {
|
||||||
MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
|
MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
|
||||||
for (const auto& cnode : order_) {
|
for (const auto &cnode : order_) {
|
||||||
MS_LOG(INFO) << cnode->DebugString();
|
MS_LOG(INFO) << cnode->DebugString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string FuncGraph::ToString() const {
|
std::string FuncGraph::ToString() const {
|
||||||
return mindspore::label_manage::Label(const_cast<FuncGraph*>(this)->shared_from_base<FuncGraph>()->debug_info());
|
return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphDebugInfoPtr FuncGraph::debug_info() {
|
GraphDebugInfoPtr FuncGraph::debug_info() {
|
||||||
|
@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
|
||||||
return this->debug_info_;
|
return this->debug_info_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodeSet& FuncGraph::nodes() {
|
const AnfNodeSet &FuncGraph::nodes() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& nodes = mng->nodes();
|
auto &nodes = mng->nodes();
|
||||||
return nodes[shared_from_base<FuncGraph>()];
|
return nodes[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodeCounterMap& FuncGraph::value_nodes() {
|
const AnfNodeCounterMap &FuncGraph::value_nodes() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& cts = mng->valuenodes();
|
auto &cts = mng->valuenodes();
|
||||||
return cts[shared_from_base<FuncGraph>()];
|
return cts[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodeCounterMap& FuncGraph::free_variables_direct() {
|
const AnfNodeCounterMap &FuncGraph::free_variables_direct() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& fv_direct = mng->free_variables_direct();
|
auto &fv_direct = mng->free_variables_direct();
|
||||||
return fv_direct[shared_from_base<FuncGraph>()];
|
return fv_direct[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
const BaseRefCounterMap& FuncGraph::free_variables_total() {
|
const BaseRefCounterMap &FuncGraph::free_variables_total() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& fv_total = mng->free_variables_total();
|
auto &fv_total = mng->free_variables_total();
|
||||||
return fv_total[shared_from_base<FuncGraph>()];
|
return fv_total[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
|
std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
|
||||||
std::vector<AnfNodePtr> nodes;
|
std::vector<AnfNodePtr> nodes;
|
||||||
const auto& fv_total = this->free_variables_total();
|
const auto &fv_total = this->free_variables_total();
|
||||||
for (auto& p : fv_total) {
|
for (auto &p : fv_total) {
|
||||||
auto key = p.first;
|
auto key = p.first;
|
||||||
if (utils::isa<AnfNodePtr>(key)) {
|
if (utils::isa<AnfNodePtr>(key)) {
|
||||||
nodes.push_back(utils::cast<AnfNodePtr>(key));
|
nodes.push_back(utils::cast<AnfNodePtr>(key));
|
||||||
|
@ -238,8 +238,8 @@ std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
|
||||||
|
|
||||||
std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
|
std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
|
||||||
std::vector<FuncGraphPtr> func_graphs;
|
std::vector<FuncGraphPtr> func_graphs;
|
||||||
const auto& fv_total = this->free_variables_total();
|
const auto &fv_total = this->free_variables_total();
|
||||||
for (auto& p : fv_total) {
|
for (auto &p : fv_total) {
|
||||||
auto key = p.first;
|
auto key = p.first;
|
||||||
if (utils::isa<FuncGraphPtr>(key)) {
|
if (utils::isa<FuncGraphPtr>(key)) {
|
||||||
func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
|
func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
|
||||||
|
@ -249,31 +249,31 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
|
||||||
return func_graphs;
|
return func_graphs;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FuncGraphCounterMap& FuncGraph::func_graphs_used() {
|
const FuncGraphCounterMap &FuncGraph::func_graphs_used() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& used = mng->func_graphs_used();
|
auto &used = mng->func_graphs_used();
|
||||||
return used[shared_from_base<FuncGraph>()];
|
return used[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
const FuncGraphSet& FuncGraph::func_graphs_used_total() {
|
const FuncGraphSet &FuncGraph::func_graphs_used_total() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
|
auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
|
||||||
return used;
|
return used;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FuncGraphCounterMap& FuncGraph::func_graph_users() {
|
const FuncGraphCounterMap &FuncGraph::func_graph_users() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& users = mng->func_graph_users();
|
auto &users = mng->func_graph_users();
|
||||||
return users[shared_from_base<FuncGraph>()];
|
return users[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() {
|
const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto& users = mng->func_graph_user_cnodes();
|
auto &users = mng->func_graph_user_cnodes();
|
||||||
return users[shared_from_base<FuncGraph>()];
|
return users[shared_from_base<FuncGraph>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() {
|
||||||
return mng->parent(shared_from_base<FuncGraph>());
|
return mng->parent(shared_from_base<FuncGraph>());
|
||||||
}
|
}
|
||||||
|
|
||||||
const FuncGraphSet& FuncGraph::children() {
|
const FuncGraphSet &FuncGraph::children() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
return mng->children(shared_from_base<FuncGraph>());
|
return mng->children(shared_from_base<FuncGraph>());
|
||||||
}
|
}
|
||||||
|
|
||||||
const FuncGraphSet& FuncGraph::scope() {
|
const FuncGraphSet &FuncGraph::scope() {
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
return mng->scopes(shared_from_base<FuncGraph>());
|
return mng->scopes(shared_from_base<FuncGraph>());
|
||||||
|
@ -312,9 +312,9 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
|
||||||
return mng->recursive_graphs(shared_from_base<FuncGraph>());
|
return mng->recursive_graphs(shared_from_base<FuncGraph>());
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
|
void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
|
||||||
|
|
||||||
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) {
|
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
|
||||||
auto itr = this->parameter_default_value_.find(name);
|
auto itr = this->parameter_default_value_.find(name);
|
||||||
if (itr == parameter_default_value_.end()) {
|
if (itr == parameter_default_value_.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// set the default values
|
// set the default values
|
||||||
void FuncGraph::SetDefaultValues(const std::vector<std::string>& name_list, const std::vector<AnfNodePtr>& value_list) {
|
void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
|
||||||
auto all_is_null = std::all_of(value_list.begin(), value_list.end(),
|
auto all_is_null = std::all_of(value_list.begin(), value_list.end(),
|
||||||
[](const AnfNodePtr& node) { return IsValueNode<NullObj>(node); });
|
[](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); });
|
||||||
if (value_list.empty()) {
|
if (value_list.empty()) {
|
||||||
all_is_null = true;
|
all_is_null = true;
|
||||||
}
|
}
|
||||||
|
@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
|
||||||
size_t FuncGraph::GetDefaultValueCount() {
|
size_t FuncGraph::GetDefaultValueCount() {
|
||||||
int null_count =
|
int null_count =
|
||||||
std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
|
std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
|
||||||
[](const std::pair<std::string, AnfNodePtr>& pair) { return IsValueNode<NullObj>(pair.second); });
|
[](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); });
|
||||||
return parameter_default_value_.size() - IntToSize(null_count);
|
return parameter_default_value_.size() - IntToSize(null_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const {
|
||||||
return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
|
return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) {
|
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
|
||||||
for (size_t i = 0; i < parameters_.size(); ++i) {
|
for (size_t i = 0; i < parameters_.size(); ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(parameters_[i]);
|
MS_EXCEPTION_IF_NULL(parameters_[i]);
|
||||||
auto param_cast = parameters_[i]->cast<ParameterPtr>();
|
auto param_cast = parameters_[i]->cast<ParameterPtr>();
|
||||||
|
@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph,
|
void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
|
||||||
std::vector<AnfNodePtr>* specialized_parameter_list,
|
std::vector<AnfNodePtr> *specialized_parameter_list,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, int variable_args_count,
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
|
||||||
int pos_args_input_count) {
|
int pos_args_input_count) {
|
||||||
// if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
|
// if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
|
||||||
if (specialized_graph->has_vararg()) {
|
if (specialized_graph->has_vararg()) {
|
||||||
|
@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
|
void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
||||||
std::vector<AnfNodePtr>* specialized_parameter_list,
|
std::vector<AnfNodePtr> *specialized_parameter_list,
|
||||||
const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list,
|
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) {
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
|
||||||
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
|
||||||
for (const auto& kwarg : kwarg_list) {
|
for (const auto &kwarg : kwarg_list) {
|
||||||
MS_EXCEPTION_IF_NULL(kwarg);
|
MS_EXCEPTION_IF_NULL(kwarg);
|
||||||
std::string kw_param_name = kwarg->get_key();
|
std::string kw_param_name = kwarg->get_key();
|
||||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||||
|
@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
|
||||||
std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
|
std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
|
||||||
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
|
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
|
||||||
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
|
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
|
||||||
[param_name](const AnfNodePtr& node) {
|
[param_name](const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto param = node->cast<ParameterPtr>();
|
auto param = node->cast<ParameterPtr>();
|
||||||
return param != nullptr && param->name() == param_name;
|
return param != nullptr && param->name() == param_name;
|
||||||
|
@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
|
||||||
GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes);
|
GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph,
|
void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes,
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
|
||||||
const std::vector<AnfNodePtr>& kwarg_keys_tuple_nodes,
|
const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
|
||||||
const std::vector<AnfNodePtr>& kwarg_values_tuple_nodes) {
|
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) {
|
||||||
if (has_kwarg()) {
|
if (has_kwarg()) {
|
||||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||||
TraceManager::DebugTrace(
|
TraceManager::DebugTrace(
|
||||||
|
@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list) {
|
bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
|
||||||
// if the function does not have any vararg/kwarg/kwonly/default value/kw args input
|
// if the function does not have any vararg/kwarg/kwonly/default value/kw args input
|
||||||
// return the original graph
|
// return the original graph
|
||||||
if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
|
if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
|
||||||
|
@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>&
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph,
|
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
|
||||||
const std::vector<AnfNodePtr>& specialized_parameter_list,
|
const std::vector<AnfNodePtr> &specialized_parameter_list,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) {
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||||
for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
|
for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
|
||||||
auto param_node = specialized_graph->parameters()[i];
|
auto param_node = specialized_graph->parameters()[i];
|
||||||
|
@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
|
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
|
||||||
size_t arguments_count = args_spec_list.size();
|
size_t arguments_count = args_spec_list.size();
|
||||||
for (const auto& arg : args_spec_list) {
|
for (const auto &arg : args_spec_list) {
|
||||||
// if it is a keyword argument
|
// if it is a keyword argument
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
if (arg->isa<abstract::AbstractKeywordArg>()) {
|
if (arg->isa<abstract::AbstractKeywordArg>()) {
|
||||||
|
@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list)
|
||||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||||
auto params = specialized_graph->parameters();
|
auto params = specialized_graph->parameters();
|
||||||
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
|
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
|
||||||
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; });
|
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
|
||||||
|
|
||||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
|
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
|
||||||
auto tr = manager->Transact();
|
auto tr = manager->Transact();
|
||||||
for (auto& node_pair : repl_nodes) {
|
for (auto &node_pair : repl_nodes) {
|
||||||
MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
|
MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
|
||||||
<< node_pair.second->DebugString();
|
<< node_pair.second->DebugString();
|
||||||
(void)tr.Replace(node_pair.first, node_pair.second);
|
(void)tr.Replace(node_pair.first, node_pair.second);
|
||||||
|
@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list)
|
||||||
return specialized_graph;
|
return specialized_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); }
|
void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
|
||||||
|
|
||||||
std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
|
std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
|
||||||
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
||||||
|
@ -651,7 +651,7 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
|
||||||
|
|
||||||
std::list<CNodePtr> cnodes;
|
std::list<CNodePtr> cnodes;
|
||||||
auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph);
|
auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph);
|
||||||
for (const auto& node : nodes) {
|
for (const auto &node : nodes) {
|
||||||
auto cnode = dyn_cast<CNode>(node);
|
auto cnode = dyn_cast<CNode>(node);
|
||||||
if (cnode) {
|
if (cnode) {
|
||||||
cnodes.push_back(cnode);
|
cnodes.push_back(cnode);
|
||||||
|
@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) {
|
void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) {
|
||||||
if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
|
if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
|
||||||
order_.remove(n->cast<CNodePtr>());
|
order_.remove(n->cast<CNodePtr>());
|
||||||
MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
|
MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
|
||||||
|
@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() {
|
||||||
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
||||||
MS_LOG(DEBUG) << "Check graph " << ToString();
|
MS_LOG(DEBUG) << "Check graph " << ToString();
|
||||||
for (auto it = order_.begin(); it != order_.end(); (void)it++) {
|
for (auto it = order_.begin(); it != order_.end(); (void)it++) {
|
||||||
for (const auto& input_node : (*it)->inputs()) {
|
for (const auto &input_node : (*it)->inputs()) {
|
||||||
if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) {
|
if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) {
|
||||||
// Need to reorder the wrong order node.
|
// Need to reorder the wrong order node.
|
||||||
auto found = std::find(order_.begin(), it, input_node);
|
auto found = std::find(order_.begin(), it, input_node);
|
||||||
|
@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() {
|
||||||
}
|
}
|
||||||
auto mng = manager_.lock();
|
auto mng = manager_.lock();
|
||||||
if (mng != nullptr) {
|
if (mng != nullptr) {
|
||||||
const auto& nodes = mng->nodes()[shared_from_base<FuncGraph>()];
|
const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()];
|
||||||
if (nodes.size() != (order_.size() + parameters_.size())) {
|
if (nodes.size() != (order_.size() + parameters_.size())) {
|
||||||
DumpCNodeList();
|
DumpCNodeList();
|
||||||
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
|
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
|
||||||
|
@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() {
|
||||||
|
|
||||||
const char kPrimHasEffect[] = "_side_effect_flag";
|
const char kPrimHasEffect[] = "_side_effect_flag";
|
||||||
|
|
||||||
bool FuncGraph::HasEffect(const CNodePtr& cnode) {
|
bool FuncGraph::HasEffect(const CNodePtr &cnode) {
|
||||||
auto prim = GetCNodePrimitive(cnode);
|
auto prim = GetCNodePrimitive(cnode);
|
||||||
if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) {
|
if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) {
|
||||||
auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>();
|
auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>();
|
||||||
|
@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& segment) {
|
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
|
||||||
std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
|
std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
|
||||||
for (const auto& node : segment) {
|
for (const auto &node : segment) {
|
||||||
if (roots->size() == 1) {
|
if (roots->size() == 1) {
|
||||||
return roots;
|
return roots;
|
||||||
}
|
}
|
||||||
|
@ -757,9 +757,9 @@ std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& seg
|
||||||
return roots;
|
return roots;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr>& segment) {
|
std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
|
||||||
std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
|
std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
|
||||||
for (const auto& node : segment) {
|
for (const auto &node : segment) {
|
||||||
if (nodes->size() == 1) {
|
if (nodes->size() == 1) {
|
||||||
return nodes;
|
return nodes;
|
||||||
}
|
}
|
||||||
|
@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
|
||||||
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
||||||
std::list<AnfNodePtr> depends_order;
|
std::list<AnfNodePtr> depends_order;
|
||||||
std::vector<CNodePtr> segment;
|
std::vector<CNodePtr> segment;
|
||||||
for (const auto& cnode : order_) {
|
for (const auto &cnode : order_) {
|
||||||
if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) {
|
if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr>& depend_inputs) {
|
void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
|
||||||
auto old_ret = output();
|
auto old_ret = output();
|
||||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret};
|
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret};
|
||||||
(void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end());
|
(void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end());
|
||||||
|
|
|
@ -26,29 +26,29 @@
|
||||||
|
|
||||||
// namespace to support intermediate representation definition
|
// namespace to support intermediate representation definition
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
|
Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
|
||||||
bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation)
|
bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
|
||||||
: clone_all_valuenodes_(clone_all_valuenodes),
|
: clone_all_valuenodes_(clone_all_valuenodes),
|
||||||
clone_all_child_graphs_(clone_all_child_graphs),
|
clone_all_child_graphs_(clone_all_child_graphs),
|
||||||
clone_all_used_graphs_(clone_all_used_graphs),
|
clone_all_used_graphs_(clone_all_used_graphs),
|
||||||
relation_(relation),
|
relation_(relation),
|
||||||
target_relation_(target_relation == nullptr ? relation : target_relation) {
|
target_relation_(target_relation == nullptr ? relation : target_relation) {
|
||||||
for (auto& func_graph : func_graphs) {
|
for (auto &func_graph : func_graphs) {
|
||||||
AddClone(func_graph);
|
AddClone(func_graph);
|
||||||
}
|
}
|
||||||
scope_ = kDefaultScope;
|
scope_ = kDefaultScope;
|
||||||
type_ = kBasic;
|
type_ = kBasic;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
|
void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
|
||||||
const AnfNodePtrList& params, CloneType type) {
|
const AnfNodePtrList ¶ms, CloneType type) {
|
||||||
if (func_graph != nullptr) {
|
if (func_graph != nullptr) {
|
||||||
todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
|
todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
|
||||||
type_ = type;
|
type_ = type;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
|
void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
|
if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
|
||||||
return;
|
return;
|
||||||
|
@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) {
|
void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(target);
|
MS_EXCEPTION_IF_NULL(target);
|
||||||
TraceManager::DebugTrace(node->debug_info(), relation_);
|
TraceManager::DebugTrace(node->debug_info(), relation_);
|
||||||
|
@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target,
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
|
void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(target);
|
MS_EXCEPTION_IF_NULL(target);
|
||||||
TraceManager::DebugTrace(node->debug_info(), relation_);
|
TraceManager::DebugTrace(node->debug_info(), relation_);
|
||||||
|
@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneValueNode(const AnfNodePtr& node) {
|
void Cloner::CloneValueNode(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
TraceManager::DebugTrace(node->debug_info(), relation_);
|
TraceManager::DebugTrace(node->debug_info(), relation_);
|
||||||
ValueNodePtr new_const = NewValueNode(GetValueNode(node));
|
ValueNodePtr new_const = NewValueNode(GetValueNode(node));
|
||||||
|
@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) {
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
|
void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(target);
|
MS_EXCEPTION_IF_NULL(target);
|
||||||
TraceManager::DebugTrace(node->debug_info(), relation_);
|
TraceManager::DebugTrace(node->debug_info(), relation_);
|
||||||
|
@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target)
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) {
|
void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
if (!clone_all_valuenodes_) {
|
if (!clone_all_valuenodes_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto& value_nodes = manager_->valuenodes()[func_graph];
|
auto &value_nodes = manager_->valuenodes()[func_graph];
|
||||||
for (auto& value_node : value_nodes) {
|
for (auto &value_node : value_nodes) {
|
||||||
auto old_node = value_node.first;
|
auto old_node = value_node.first;
|
||||||
MS_EXCEPTION_IF_NULL(old_node);
|
MS_EXCEPTION_IF_NULL(old_node);
|
||||||
if (repl_node_.count(old_node) == 0) {
|
if (repl_node_.count(old_node) == 0) {
|
||||||
|
@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) {
|
void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
if (!clone_all_child_graphs_) {
|
if (!clone_all_child_graphs_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto& scopes = manager_->scopes(func_graph);
|
auto &scopes = manager_->scopes(func_graph);
|
||||||
for (auto& graph : scopes) {
|
for (auto &graph : scopes) {
|
||||||
if (graph != func_graph) {
|
if (graph != func_graph) {
|
||||||
todo_.push_back({graph, nullptr, {}});
|
todo_.push_back({graph, nullptr, {}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) {
|
void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
if (!clone_all_used_graphs_) {
|
if (!clone_all_used_graphs_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto& used_graphs = manager_->func_graphs_used()[func_graph];
|
auto &used_graphs = manager_->func_graphs_used()[func_graph];
|
||||||
for (auto& used_graph : used_graphs) {
|
for (auto &used_graph : used_graphs) {
|
||||||
todo_.push_back({used_graph.first, nullptr, {}});
|
todo_.push_back({used_graph.first, nullptr, {}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
|
void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||||
for (auto& item : func_graph->parameter_default_value()) {
|
for (auto &item : func_graph->parameter_default_value()) {
|
||||||
auto nodes = DeepLinkedGraphSearch(item.second);
|
auto nodes = DeepLinkedGraphSearch(item.second);
|
||||||
for (auto& node : nodes) {
|
for (auto &node : nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
CloneNode(node, target_func_graph);
|
CloneNode(node, target_func_graph);
|
||||||
|
@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
|
void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
|
@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func
|
||||||
}
|
}
|
||||||
target_func_graph->set_return(return_node);
|
target_func_graph->set_return(return_node);
|
||||||
|
|
||||||
auto& value_nodes = manager_->func_graph_valuenodes()[func_graph];
|
auto &value_nodes = manager_->func_graph_valuenodes()[func_graph];
|
||||||
for (auto& value_node : value_nodes) {
|
for (auto &value_node : value_nodes) {
|
||||||
CloneValueNode(value_node.first, target_func_graph);
|
CloneValueNode(value_node.first, target_func_graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) {
|
void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto& old_params = func_graph->parameters();
|
auto &old_params = func_graph->parameters();
|
||||||
if (old_params.size() != params.size()) {
|
if (old_params.size() != params.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
|
MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
|
||||||
return;
|
return;
|
||||||
|
@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) {
|
void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||||
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
|
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
|
||||||
|
@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
|
void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||||
auto& params = func_graph->parameters();
|
auto ¶ms = func_graph->parameters();
|
||||||
for (auto& param : params) {
|
for (auto ¶m : params) {
|
||||||
CloneParameter(param, target_func_graph, true);
|
CloneParameter(param, target_func_graph, true);
|
||||||
}
|
}
|
||||||
repl_func_graph_[func_graph] = target_func_graph;
|
repl_func_graph_[func_graph] = target_func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::GenParameters(const FuncGraphPtr& func_graph) {
|
void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto& free_vars = manager_->free_variables_total();
|
auto &free_vars = manager_->free_variables_total();
|
||||||
auto iter = free_vars.find(func_graph);
|
auto iter = free_vars.find(func_graph);
|
||||||
if (iter == free_vars.end()) {
|
if (iter == free_vars.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& fv_map : iter->second) {
|
for (auto &fv_map : iter->second) {
|
||||||
auto& free_var = fv_map.first;
|
auto &free_var = fv_map.first;
|
||||||
if (utils::isa<AnfNodePtr>(free_var)) {
|
if (utils::isa<AnfNodePtr>(free_var)) {
|
||||||
repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
|
repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) {
|
void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) {
|
||||||
param->set_abstract(node->abstract());
|
param->set_abstract(node->abstract());
|
||||||
if (node->isa<Parameter>()) {
|
if (node->isa<Parameter>()) {
|
||||||
ParameterPtr old_param = dyn_cast<Parameter>(node);
|
ParameterPtr old_param = dyn_cast<Parameter>(node);
|
||||||
|
@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) {
|
ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
|
||||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
|
||||||
ParameterPtr param = std::make_shared<Parameter>(func_graph);
|
ParameterPtr param = std::make_shared<Parameter>(func_graph);
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
|
@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP
|
||||||
return param;
|
return param;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params,
|
void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms,
|
||||||
AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) {
|
AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
|
||||||
AnfNodePtrList parameters;
|
AnfNodePtrList parameters;
|
||||||
std::unordered_set<AnfNodePtr> old_params;
|
std::unordered_set<AnfNodePtr> old_params;
|
||||||
for (auto& param : func_graph->parameters()) {
|
for (auto ¶m : func_graph->parameters()) {
|
||||||
auto iter = repl_node_.find(param);
|
auto iter = repl_node_.find(param);
|
||||||
if (iter != repl_node_.end()) {
|
if (iter != repl_node_.end()) {
|
||||||
(void)old_params.insert(iter->second);
|
(void)old_params.insert(iter->second);
|
||||||
|
@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList&
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnfNodePtr new_param = nullptr;
|
AnfNodePtr new_param = nullptr;
|
||||||
for (auto& param : params) {
|
for (auto ¶m : params) {
|
||||||
auto old_param = repl_node_[param];
|
auto old_param = repl_node_[param];
|
||||||
if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
|
if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
|
||||||
repl_node_[old_param] = old_param;
|
repl_node_[old_param] = old_param;
|
||||||
|
@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList&
|
||||||
func_graph->set_parameters(parameters);
|
func_graph->set_parameters(parameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
|
void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
|
||||||
const AnfNodePtrList& params) {
|
const AnfNodePtrList ¶ms) {
|
||||||
AnfNodePtr node = nullptr;
|
AnfNodePtr node = nullptr;
|
||||||
auto& repl_func_graph = repl_map_func_graph_[func_graph_user];
|
auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
|
||||||
auto iter = repl_func_graph.find(func_graph);
|
auto iter = repl_func_graph.find(func_graph);
|
||||||
if (iter == repl_func_graph.end()) {
|
if (iter == repl_func_graph.end()) {
|
||||||
node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
|
node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
|
||||||
|
@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr&
|
||||||
OrderParameters(func_graph, inputs);
|
OrderParameters(func_graph, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) {
|
void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
|
||||||
std::unordered_set<AnfNodePtr> old_params;
|
std::unordered_set<AnfNodePtr> old_params;
|
||||||
for (auto& param : func_graph->parameters()) {
|
for (auto ¶m : func_graph->parameters()) {
|
||||||
(void)old_params.insert(repl_node_[param]);
|
(void)old_params.insert(repl_node_[param]);
|
||||||
}
|
}
|
||||||
std::unordered_set<AnfNodePtr> new_params;
|
std::unordered_set<AnfNodePtr> new_params;
|
||||||
|
@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis
|
||||||
(void)new_params.insert(new_param);
|
(void)new_params.insert(new_param);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& param : func_graph->parameters()) {
|
for (auto ¶m : func_graph->parameters()) {
|
||||||
if (new_params.find(param) == new_params.end()) {
|
if (new_params.find(param) == new_params.end()) {
|
||||||
parameters.push_back(param);
|
parameters.push_back(param);
|
||||||
}
|
}
|
||||||
|
@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis
|
||||||
func_graph->set_parameters(parameters);
|
func_graph->set_parameters(parameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
|
void Cloner::SetEdges(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
for (auto& node : func_graph->nodes()) {
|
for (auto &node : func_graph->nodes()) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
auto& inputs = cnode->inputs();
|
auto &inputs = cnode->inputs();
|
||||||
for (size_t i = 0; i < inputs.size(); i++) {
|
for (size_t i = 0; i < inputs.size(); i++) {
|
||||||
auto& input = inputs[i];
|
auto &input = inputs[i];
|
||||||
if (IsValueNode<FuncGraph>(input)) {
|
if (IsValueNode<FuncGraph>(input)) {
|
||||||
auto graph = GetValueNode<FuncGraphPtr>(input);
|
auto graph = GetValueNode<FuncGraphPtr>(input);
|
||||||
auto& repl_func_graph = repl_map_func_graph_[func_graph];
|
auto &repl_func_graph = repl_map_func_graph_[func_graph];
|
||||||
if (repl_func_graph.find(graph) != repl_func_graph.end()) {
|
if (repl_func_graph.find(graph) != repl_func_graph.end()) {
|
||||||
transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
|
transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto& repl_node = repl_map_node_[func_graph];
|
auto &repl_node = repl_map_node_[func_graph];
|
||||||
if (repl_node.find(input) != repl_node.end()) {
|
if (repl_node.find(input) != repl_node.end()) {
|
||||||
transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
|
transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
|
||||||
}
|
}
|
||||||
|
@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
|
void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
|
||||||
const AnfNodePtrList& params) {
|
const AnfNodePtrList ¶ms) {
|
||||||
AnfNodePtrList lift_params;
|
AnfNodePtrList lift_params;
|
||||||
AnfNodePtrList input_params;
|
AnfNodePtrList input_params;
|
||||||
AddParameters(func_graph_user, params, &lift_params, &input_params);
|
AddParameters(func_graph_user, params, &lift_params, &input_params);
|
||||||
|
@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph
|
||||||
if (lift_params.empty()) {
|
if (lift_params.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto& user : func_graph_user->func_graph_users()) {
|
for (auto &user : func_graph_user->func_graph_users()) {
|
||||||
LiftParameters(user.first, func_graph_user, lift_params);
|
LiftParameters(user.first, func_graph_user, lift_params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::Lift() {
|
void Cloner::Lift() {
|
||||||
for (auto& func_graph_params : repl_func_graph_params_) {
|
for (auto &func_graph_params : repl_func_graph_params_) {
|
||||||
auto& func_graph = func_graph_params.first;
|
auto &func_graph = func_graph_params.first;
|
||||||
auto& params = func_graph_params.second;
|
auto ¶ms = func_graph_params.second;
|
||||||
for (auto& user : func_graph->func_graph_users()) {
|
for (auto &user : func_graph->func_graph_users()) {
|
||||||
LiftParameters(user.first, func_graph, params);
|
LiftParameters(user.first, func_graph, params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -404,18 +404,18 @@ void Cloner::Lift() {
|
||||||
void Cloner::LiftParameters() {
|
void Cloner::LiftParameters() {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
transaction_ = manager_->Transact();
|
transaction_ = manager_->Transact();
|
||||||
const FuncGraphSet& func_graphs = manager_->func_graphs();
|
const FuncGraphSet &func_graphs = manager_->func_graphs();
|
||||||
for (auto& func_graph : func_graphs) {
|
for (auto &func_graph : func_graphs) {
|
||||||
GenParameters(func_graph);
|
GenParameters(func_graph);
|
||||||
}
|
}
|
||||||
Lift();
|
Lift();
|
||||||
for (auto& func_graph : func_graphs) {
|
for (auto &func_graph : func_graphs) {
|
||||||
SetEdges(func_graph);
|
SetEdges(func_graph);
|
||||||
}
|
}
|
||||||
transaction_.Commit();
|
transaction_.Commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) {
|
bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
// Make sure only inline once
|
// Make sure only inline once
|
||||||
if (status_.count(func_graph) != 0) {
|
if (status_.count(func_graph) != 0) {
|
||||||
|
@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
|
void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
const AnfNodeSet& nodes = manager_->nodes()[func_graph];
|
const AnfNodeSet &nodes = manager_->nodes()[func_graph];
|
||||||
for (auto& node : nodes) {
|
for (auto &node : nodes) {
|
||||||
CloneNode(node, target_func_graph);
|
CloneNode(node, target_func_graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -449,7 +449,7 @@ void Cloner::Run() {
|
||||||
// Basic and Inline Clone
|
// Basic and Inline Clone
|
||||||
FuncGraphPtrList func_graphs;
|
FuncGraphPtrList func_graphs;
|
||||||
(void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
|
(void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
|
||||||
[](const CloneInfo& item) -> FuncGraphPtr { return item.origin; });
|
[](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
|
||||||
manager_ = Manage(func_graphs, false);
|
manager_ = Manage(func_graphs, false);
|
||||||
CloneNodes();
|
CloneNodes();
|
||||||
LinkEdges();
|
LinkEdges();
|
||||||
|
@ -495,13 +495,13 @@ void Cloner::CloneNodes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::LinkEdges() {
|
void Cloner::LinkEdges() {
|
||||||
for (auto& node_pair : nodes_) {
|
for (auto &node_pair : nodes_) {
|
||||||
CNodePtr old_node = node_pair.first;
|
CNodePtr old_node = node_pair.first;
|
||||||
CNodePtr new_node = node_pair.second;
|
CNodePtr new_node = node_pair.second;
|
||||||
MS_EXCEPTION_IF_NULL(old_node);
|
MS_EXCEPTION_IF_NULL(old_node);
|
||||||
MS_EXCEPTION_IF_NULL(new_node);
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
for (auto& input : old_node->inputs()) {
|
for (auto &input : old_node->inputs()) {
|
||||||
auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
|
auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
|
||||||
new_node->add_input(new_input);
|
new_node->add_input(new_input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -509,10 +509,10 @@ void Cloner::LinkEdges() {
|
||||||
|
|
||||||
// For the graphs cloned, update its default value map to the cloned nodes
|
// For the graphs cloned, update its default value map to the cloned nodes
|
||||||
void Cloner::SetDefaults() {
|
void Cloner::SetDefaults() {
|
||||||
for (auto& item : graph_set_) {
|
for (auto &item : graph_set_) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
if (repl_func_graph_.count(item) != 0) {
|
if (repl_func_graph_.count(item) != 0) {
|
||||||
for (auto& param_def : item->parameter_default_value()) {
|
for (auto ¶m_def : item->parameter_default_value()) {
|
||||||
MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
|
MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
|
||||||
if (repl_node_.count(param_def.second) != 0) {
|
if (repl_node_.count(param_def.second) != 0) {
|
||||||
repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
|
repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
|
||||||
|
@ -524,7 +524,7 @@ void Cloner::SetDefaults() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) {
|
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
|
||||||
MS_EXCEPTION_IF_NULL(root);
|
MS_EXCEPTION_IF_NULL(root);
|
||||||
if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
|
if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
|
||||||
MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
|
MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
|
||||||
|
@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) {
|
||||||
MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
|
MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr Cloner::operator[](const AnfNodePtr& node) {
|
AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
|
||||||
#ifdef ENABLE_PROFILE
|
#ifdef ENABLE_PROFILE
|
||||||
double time = GetTime();
|
double time = GetTime();
|
||||||
#endif
|
#endif
|
||||||
|
@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) {
|
||||||
return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
|
return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) {
|
FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
|
||||||
#ifdef ENABLE_PROFILE
|
#ifdef ENABLE_PROFILE
|
||||||
double time = GetTime();
|
double time = GetTime();
|
||||||
#endif
|
#endif
|
||||||
|
@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) {
|
||||||
return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
|
return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) {
|
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
|
Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
|
||||||
return cloner[func_graph];
|
return cloner[func_graph];
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
|
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
|
||||||
const AnfNodePtrList& func_graph_args, const ScopePtr& scope) {
|
const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||||
Cloner cloner({}, false);
|
Cloner cloner({}, false);
|
||||||
|
@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe
|
||||||
return cloner[func_graph->output()];
|
return cloner[func_graph->output()];
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) {
|
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
Cloner cloner({}, false);
|
Cloner cloner({}, false);
|
||||||
cloner.AddClone(func_graph, nullptr, {}, kLifting);
|
cloner.AddClone(func_graph, nullptr, {}, kLifting);
|
||||||
return cloner[func_graph];
|
return cloner[func_graph];
|
||||||
}
|
}
|
||||||
|
|
||||||
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) {
|
ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
FuncGraphPtrList func_graphs = {func_graph};
|
FuncGraphPtrList func_graphs = {func_graph};
|
||||||
ClonerPtr cloner =
|
ClonerPtr cloner =
|
||||||
|
@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r
|
||||||
return cloner;
|
return cloner;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) {
|
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
TraceManager::DebugTrace(func_graph->debug_info(), relation);
|
TraceManager::DebugTrace(func_graph->debug_info(), relation);
|
||||||
auto new_func_graph = std::make_shared<FuncGraph>();
|
auto new_func_graph = std::make_shared<FuncGraph>();
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
|
|
||||||
auto& parameters = func_graph->parameters();
|
auto ¶meters = func_graph->parameters();
|
||||||
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void {
|
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void {
|
||||||
MS_EXCEPTION_IF_NULL(param);
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
|
||||||
(void)new_func_graph->add_parameter();
|
(void)new_func_graph->add_parameter();
|
||||||
|
@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP
|
||||||
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
|
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
|
||||||
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
|
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
|
||||||
new_func_graph->set_is_generate(func_graph->is_generated());
|
new_func_graph->set_is_generate(func_graph->is_generated());
|
||||||
for (auto& item : func_graph->parameter_default_value()) {
|
for (auto &item : func_graph->parameter_default_value()) {
|
||||||
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,26 +43,26 @@ struct CloneInfo {
|
||||||
|
|
||||||
class Cloner {
|
class Cloner {
|
||||||
public:
|
public:
|
||||||
explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false,
|
explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false,
|
||||||
bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
|
bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
|
||||||
const TraceInfoPtr& relation = std::make_shared<TraceCopy>(),
|
const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
|
||||||
const TraceInfoPtr& target_relation = nullptr);
|
const TraceInfoPtr &target_relation = nullptr);
|
||||||
~Cloner() = default;
|
~Cloner() = default;
|
||||||
void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr,
|
void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
|
||||||
const AnfNodePtrList& params = {}, CloneType type = kBasic);
|
const AnfNodePtrList ¶ms = {}, CloneType type = kBasic);
|
||||||
void Run();
|
void Run();
|
||||||
|
|
||||||
// Interfaces for specializer
|
// Interfaces for specializer
|
||||||
AnfNodePtr CloneDisconnected(const AnfNodePtr& root);
|
AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
|
||||||
AnfNodePtr operator[](const AnfNodePtr& node);
|
AnfNodePtr operator[](const AnfNodePtr &node);
|
||||||
FuncGraphPtr operator[](const FuncGraphPtr& func_graph);
|
FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
// Map of replicate nodes and graphs
|
// Map of replicate nodes and graphs
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; }
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; }
|
||||||
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; }
|
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; }
|
||||||
|
|
||||||
// Scope of cloned graphs
|
// Scope of cloned graphs
|
||||||
void set_scope(const ScopePtr& scope) { scope_ = scope; }
|
void set_scope(const ScopePtr &scope) { scope_ = scope; }
|
||||||
const ScopePtr scope() const { return scope_; }
|
const ScopePtr scope() const { return scope_; }
|
||||||
|
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
|
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
|
||||||
|
@ -71,31 +71,31 @@ class Cloner {
|
||||||
void CloneNodes();
|
void CloneNodes();
|
||||||
void LinkEdges();
|
void LinkEdges();
|
||||||
void SetDefaults();
|
void SetDefaults();
|
||||||
void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target);
|
void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
|
||||||
void CloneValueNode(const AnfNodePtr& node);
|
void CloneValueNode(const AnfNodePtr &node);
|
||||||
void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target);
|
void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
|
||||||
void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target);
|
void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
|
||||||
void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false);
|
void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
|
||||||
void CloneValueNodes(const FuncGraphPtr& func_graph);
|
void CloneValueNodes(const FuncGraphPtr &func_graph);
|
||||||
void AddChildGraphs(const FuncGraphPtr& func_graph);
|
void AddChildGraphs(const FuncGraphPtr &func_graph);
|
||||||
void AddTotalGraphs(const FuncGraphPtr& func_graph);
|
void AddTotalGraphs(const FuncGraphPtr &func_graph);
|
||||||
bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline);
|
bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
|
||||||
void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
|
void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
|
||||||
void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
|
void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
|
||||||
void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
|
void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
|
||||||
void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params);
|
void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms);
|
||||||
void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph);
|
void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph);
|
||||||
void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
|
void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
|
||||||
void GenParameters(const FuncGraphPtr& func_graph);
|
void GenParameters(const FuncGraphPtr &func_graph);
|
||||||
void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node);
|
void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node);
|
||||||
ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true);
|
ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
|
||||||
void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params,
|
void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, AnfNodePtrList *const lift_params,
|
||||||
AnfNodePtrList* const input_params);
|
AnfNodePtrList *const input_params);
|
||||||
void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params);
|
void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms);
|
||||||
void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs);
|
void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs);
|
||||||
void SetEdges(const FuncGraphPtr& func_graph);
|
void SetEdges(const FuncGraphPtr &func_graph);
|
||||||
void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
|
void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
|
||||||
const AnfNodePtrList& params);
|
const AnfNodePtrList ¶ms);
|
||||||
void Lift();
|
void Lift();
|
||||||
void LiftParameters();
|
void LiftParameters();
|
||||||
|
|
||||||
|
@ -118,17 +118,17 @@ class Cloner {
|
||||||
std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
|
std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
|
||||||
};
|
};
|
||||||
|
|
||||||
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph);
|
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
|
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
|
||||||
const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr);
|
const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr);
|
||||||
|
|
||||||
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph);
|
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation);
|
ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
|
||||||
|
|
||||||
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph,
|
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
|
||||||
const TraceInfoPtr& relation = std::make_shared<TraceTransform>());
|
const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_
|
#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_
|
||||||
|
|
|
@ -27,17 +27,17 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs, bool manage) {
|
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
|
||||||
auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
|
auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
|
||||||
m->Init();
|
m->Init();
|
||||||
return m;
|
return m;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage) {
|
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
|
||||||
FuncGraphManagerPtr m = nullptr;
|
FuncGraphManagerPtr m = nullptr;
|
||||||
bool root = false;
|
bool root = false;
|
||||||
|
|
||||||
for (auto& fg : func_graphs) {
|
for (auto &fg : func_graphs) {
|
||||||
if (fg == nullptr) {
|
if (fg == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool ma
|
||||||
root = true;
|
root = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& fg : func_graphs) {
|
for (auto &fg : func_graphs) {
|
||||||
if (fg == nullptr) {
|
if (fg == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
|
||||||
return Manage(func_graphs, manage);
|
return Manage(func_graphs, manage);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage)
|
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
|
||||||
: roots_(roots), is_manage_(manage) {
|
: roots_(roots), is_manage_(manage) {
|
||||||
Reset();
|
Reset();
|
||||||
}
|
}
|
||||||
|
@ -103,12 +103,12 @@ void FuncGraphManager::Init() {
|
||||||
auto roots = roots_;
|
auto roots = roots_;
|
||||||
roots_ = FuncGraphSet();
|
roots_ = FuncGraphSet();
|
||||||
|
|
||||||
for (auto& fg : roots) {
|
for (auto &fg : roots) {
|
||||||
AddFuncGraph(fg, true);
|
AddFuncGraph(fg, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const {
|
FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
|
MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
|
||||||
func_graph_parents_total_->Recompute(fg);
|
func_graph_parents_total_->Recompute(fg);
|
||||||
|
@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg)
|
||||||
return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
|
return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const {
|
FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph_parent_);
|
MS_EXCEPTION_IF_NULL(func_graph_parent_);
|
||||||
MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
|
MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
|
||||||
|
@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const {
|
||||||
return func_graph_parent_->parent_analysis()[fg];
|
return func_graph_parent_->parent_analysis()[fg];
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const {
|
FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
MS_EXCEPTION_IF_NULL(children_);
|
MS_EXCEPTION_IF_NULL(children_);
|
||||||
MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
|
MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
|
||||||
|
@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const {
|
||||||
return children_->children_analysis()[fg];
|
return children_->children_analysis()[fg];
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const {
|
FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
MS_EXCEPTION_IF_NULL(scopes_);
|
MS_EXCEPTION_IF_NULL(scopes_);
|
||||||
MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
|
MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
|
||||||
|
@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const {
|
||||||
return scopes_->scope_analysis()[fg];
|
return scopes_->scope_analysis()[fg];
|
||||||
}
|
}
|
||||||
|
|
||||||
FVTotalMap& FuncGraphManager::free_variables_total() const {
|
FVTotalMap &FuncGraphManager::free_variables_total() const {
|
||||||
MS_EXCEPTION_IF_NULL(free_variables_total_);
|
MS_EXCEPTION_IF_NULL(free_variables_total_);
|
||||||
free_variables_total_->Recompute();
|
free_variables_total_->Recompute();
|
||||||
return free_variables_total_->fv_total_analysis();
|
return free_variables_total_->fv_total_analysis();
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const {
|
FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
|
MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
|
||||||
func_graphs_used_total_->Recompute(fg);
|
func_graphs_used_total_->Recompute(fg);
|
||||||
return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
|
return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const {
|
bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
recursive_->Recompute(fg);
|
recursive_->Recompute(fg);
|
||||||
if (recursive_->recursive_analysis().count(fg) == 0) {
|
if (recursive_->recursive_analysis().count(fg) == 0) {
|
||||||
|
@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const {
|
||||||
return recursive_->recursive_analysis()[fg];
|
return recursive_->recursive_analysis()[fg];
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const {
|
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
if (recursive(fg)) {
|
if (recursive(fg)) {
|
||||||
if (!recursive_->recursive_map().count(fg)) {
|
if (!recursive_->recursive_map().count(fg)) {
|
||||||
|
@ -185,7 +185,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const {
|
bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const {
|
||||||
MS_EXCEPTION_IF_NULL(j_total_);
|
MS_EXCEPTION_IF_NULL(j_total_);
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
j_total_->Recompute(fg);
|
j_total_->Recompute(fg);
|
||||||
|
@ -225,10 +225,10 @@ void FuncGraphManager::Clear() {
|
||||||
signals_->InvalidateComputer();
|
signals_->InvalidateComputer();
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
|
void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
|
||||||
MS_LOG(DEBUG) << "Start keep roots";
|
MS_LOG(DEBUG) << "Start keep roots";
|
||||||
bool root_exist = false;
|
bool root_exist = false;
|
||||||
for (auto& item : func_graphs) {
|
for (auto &item : func_graphs) {
|
||||||
if (roots_.contains(item)) {
|
if (roots_.contains(item)) {
|
||||||
root_exist = true;
|
root_exist = true;
|
||||||
break;
|
break;
|
||||||
|
@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
|
||||||
roots = roots_;
|
roots = roots_;
|
||||||
} else {
|
} else {
|
||||||
roots_.clear();
|
roots_.clear();
|
||||||
for (auto& item : roots) {
|
for (auto &item : roots) {
|
||||||
AddFuncGraph(item, true);
|
AddFuncGraph(item, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSet keep;
|
FuncGraphSet keep;
|
||||||
for (auto& item : roots) {
|
for (auto &item : roots) {
|
||||||
MS_LOG(DEBUG) << "roots: " << item->ToString();
|
MS_LOG(DEBUG) << "roots: " << item->ToString();
|
||||||
keep.update(func_graphs_used_total(item));
|
keep.update(func_graphs_used_total(item));
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
for (auto& k : keep) {
|
for (auto &k : keep) {
|
||||||
MS_LOG(DEBUG) << "keep: " << k->ToString();
|
MS_LOG(DEBUG) << "keep: " << k->ToString();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
|
||||||
} else {
|
} else {
|
||||||
Clear();
|
Clear();
|
||||||
FuncGraphSet roots(func_graphs);
|
FuncGraphSet roots(func_graphs);
|
||||||
for (auto& item : roots) {
|
for (auto &item : roots) {
|
||||||
AddFuncGraph(item, true);
|
AddFuncGraph(item, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() {
|
||||||
MaybeDropFuncGraphs(func_graphs_, true);
|
MaybeDropFuncGraphs(func_graphs_, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) {
|
void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
if (is_manage_) {
|
if (is_manage_) {
|
||||||
if (fg->manager() != nullptr && (&(*fg->manager()) != this)) {
|
if (fg->manager() != nullptr && (&(*fg->manager()) != this)) {
|
||||||
|
@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) {
|
||||||
func_graphs_.add(fg);
|
func_graphs_.add(fg);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) {
|
void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
|
||||||
FuncGraphSet todo(func_graphs);
|
FuncGraphSet todo(func_graphs);
|
||||||
std::set<FuncGraphPtr> dropped;
|
std::set<FuncGraphPtr> dropped;
|
||||||
// int count = 0;
|
// int count = 0;
|
||||||
|
@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(func_graph_users_);
|
MS_EXCEPTION_IF_NULL(func_graph_users_);
|
||||||
auto& users = func_graph_users_->count_func_graphs_map()[func_graph];
|
auto &users = func_graph_users_->count_func_graphs_map()[func_graph];
|
||||||
if (!users.empty() && !ignore_users) {
|
if (!users.empty() && !ignore_users) {
|
||||||
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
|
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
|
||||||
continue;
|
continue;
|
||||||
|
@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool
|
||||||
todo.update(MaybeDropNodes(return_vec));
|
todo.update(MaybeDropNodes(return_vec));
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
MS_EXCEPTION_IF_NULL(signals_);
|
||||||
for (auto& fg : dropped) {
|
for (auto &fg : dropped) {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
signals_->DropFuncGraph(fg);
|
signals_->DropFuncGraph(fg);
|
||||||
all_nodes_.difference_update(fg->parameters());
|
all_nodes_.difference_update(fg->parameters());
|
||||||
|
@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
|
||||||
MS_EXCEPTION_IF_NULL(inp);
|
MS_EXCEPTION_IF_NULL(inp);
|
||||||
if (direction == kDecEdge) {
|
if (direction == kDecEdge) {
|
||||||
MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString();
|
MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString();
|
||||||
auto& users_node = node_users_[inp];
|
auto &users_node = node_users_[inp];
|
||||||
if (!users_node.contains(make_pair(node, index))) {
|
if (!users_node.contains(make_pair(node, index))) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
|
||||||
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
|
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
|
||||||
AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
|
AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
|
||||||
}
|
}
|
||||||
auto& users_node = node_users_[inp];
|
auto &users_node = node_users_[inp];
|
||||||
users_node.add(make_pair(node, index));
|
users_node.add(make_pair(node, index));
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
MS_EXCEPTION_IF_NULL(signals_);
|
||||||
signals_->AddEdge(node, index, inp);
|
signals_->AddEdge(node, index, inp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) {
|
void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (auto& inp : cnode->inputs()) {
|
for (auto &inp : cnode->inputs()) {
|
||||||
ProcessEdge(cnode, index, inp, direction);
|
ProcessEdge(cnode, index, inp, direction);
|
||||||
++index;
|
++index;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) {
|
IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) {
|
||||||
if (all_nodes_.contains(node)) {
|
if (all_nodes_.contains(node)) {
|
||||||
return EXCLUDE;
|
return EXCLUDE;
|
||||||
} else {
|
} else {
|
||||||
|
@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
|
void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
|
||||||
AnfNodeSet acq;
|
AnfNodeSet acq;
|
||||||
for (auto& node : nodes) {
|
for (auto &node : nodes) {
|
||||||
std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1);
|
std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1);
|
||||||
|
|
||||||
AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit));
|
AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit));
|
||||||
|
@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
|
||||||
acq.update(new_nodes);
|
acq.update(new_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& node : acq) {
|
for (auto &node : acq) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
FuncGraphPtr fg = node->func_graph();
|
FuncGraphPtr fg = node->func_graph();
|
||||||
if (fg != nullptr) {
|
if (fg != nullptr) {
|
||||||
|
@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& nodes) {
|
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
|
||||||
AnfNodeSet nodes_ordered(nodes);
|
AnfNodeSet nodes_ordered(nodes);
|
||||||
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
|
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
MS_EXCEPTION_IF_NULL(signals_);
|
||||||
|
@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>&
|
||||||
if (!all_nodes_.contains(node)) {
|
if (!all_nodes_.contains(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
AnfNodeIndexSet& users = node_users_[node];
|
AnfNodeIndexSet &users = node_users_[node];
|
||||||
|
|
||||||
std::vector<AnfNodePtr> parameters;
|
std::vector<AnfNodePtr> parameters;
|
||||||
if (!users.empty() ||
|
if (!users.empty() ||
|
||||||
|
@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>&
|
||||||
return func_graphs_to_check;
|
return func_graphs_to_check;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters) {
|
void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters) {
|
||||||
auto tr = Transact();
|
auto tr = Transact();
|
||||||
tr.SetParameters(fg, parameters);
|
tr.SetParameters(fg, parameters);
|
||||||
tr.Commit();
|
tr.Commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) {
|
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||||
auto tr = Transact();
|
auto tr = Transact();
|
||||||
bool success = tr.Replace(old_node, new_node);
|
bool success = tr.Replace(old_node, new_node);
|
||||||
if (success) {
|
if (success) {
|
||||||
|
@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new
|
||||||
return success;
|
return success;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) {
|
void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
|
||||||
auto tr = Transact();
|
auto tr = Transact();
|
||||||
tr.SetEdge(node, index, value);
|
tr.SetEdge(node, index, value);
|
||||||
tr.Commit();
|
tr.Commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) {
|
void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) {
|
||||||
AnfNodePtr source_return = source->get_return();
|
AnfNodePtr source_return = source->get_return();
|
||||||
AnfNodePtr source_output = source->output();
|
AnfNodePtr source_output = source->output();
|
||||||
AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
|
AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
|
||||||
|
@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
|
||||||
(void)all_nodes_.erase(source_return);
|
(void)all_nodes_.erase(source_return);
|
||||||
(void)node_users_.erase(source_return);
|
(void)node_users_.erase(source_return);
|
||||||
signals_->DropNode(source_return);
|
signals_->DropNode(source_return);
|
||||||
for (auto& node : source->nodes()) {
|
for (auto &node : source->nodes()) {
|
||||||
node->set_func_graph(target);
|
node->set_func_graph(target);
|
||||||
if (node->scope() == kDefaultScope) {
|
if (node->scope() == kDefaultScope) {
|
||||||
node->set_scope(scope);
|
node->set_scope(scope);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& used : source->func_graphs_used()) {
|
for (auto &used : source->func_graphs_used()) {
|
||||||
(void)func_graph_users_->Inc(used.first, target, used.second);
|
(void)func_graph_users_->Inc(used.first, target, used.second);
|
||||||
(void)this->func_graph_users()[used.first].erase(source);
|
(void)this->func_graph_users()[used.first].erase(source);
|
||||||
}
|
}
|
||||||
for (auto& child : this->func_graph_child_direct()[source]) {
|
for (auto &child : this->func_graph_child_direct()[source]) {
|
||||||
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
|
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
|
||||||
(void)this->func_graph_parents_direct()[child.first].erase(source);
|
(void)this->func_graph_parents_direct()[child.first].erase(source);
|
||||||
}
|
}
|
||||||
for (auto& fv_count : this->free_variables_direct()[source]) {
|
for (auto &fv_count : this->free_variables_direct()[source]) {
|
||||||
auto fv_g = fv_count.first->func_graph();
|
auto fv_g = fv_count.first->func_graph();
|
||||||
auto& count_on_g = this->func_graph_child_direct()[fv_g];
|
auto &count_on_g = this->func_graph_child_direct()[fv_g];
|
||||||
auto pair = count_on_g.find(source);
|
auto pair = count_on_g.find(source);
|
||||||
if (fv_g != target && pair != count_on_g.end()) {
|
if (fv_g != target && pair != count_on_g.end()) {
|
||||||
(void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
|
(void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
|
||||||
|
@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() {
|
||||||
return tr;
|
return tr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges,
|
void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges,
|
||||||
EdgeTupleCounter* rm_edges, Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms) {
|
EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) {
|
||||||
for (auto& iter : changes) {
|
for (auto &iter : changes) {
|
||||||
auto operation = iter.op;
|
auto operation = iter.op;
|
||||||
auto args = iter.args;
|
auto args = iter.args;
|
||||||
if (operation == Change::kTxSetEdge) {
|
if (operation == Change::kTxSetEdge) {
|
||||||
|
@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl
|
||||||
auto param = args.cast<ArgsOfSetParams>();
|
auto param = args.cast<ArgsOfSetParams>();
|
||||||
MS_EXCEPTION_IF_NULL(param.func_graph);
|
MS_EXCEPTION_IF_NULL(param.func_graph);
|
||||||
auto old_parameters = param.func_graph->parameters();
|
auto old_parameters = param.func_graph->parameters();
|
||||||
for (auto& p : param.params) {
|
for (auto &p : param.params) {
|
||||||
(*adds)[p] += 1;
|
(*adds)[p] += 1;
|
||||||
}
|
}
|
||||||
for (auto& p : old_parameters) {
|
for (auto &p : old_parameters) {
|
||||||
(*rms)[p] += 1;
|
(*rms)[p] += 1;
|
||||||
}
|
}
|
||||||
param.func_graph->set_parameters(param.params);
|
param.func_graph->set_parameters(param.params);
|
||||||
|
@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
|
void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) {
|
||||||
EdgeTupleCounter add_edges;
|
EdgeTupleCounter add_edges;
|
||||||
EdgeTupleCounter rm_edges;
|
EdgeTupleCounter rm_edges;
|
||||||
Counter<AnfNodePtr> adds;
|
Counter<AnfNodePtr> adds;
|
||||||
|
@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
|
||||||
ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms);
|
ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms);
|
||||||
|
|
||||||
auto sub_edges = add_edges - rm_edges;
|
auto sub_edges = add_edges - rm_edges;
|
||||||
for (auto& iter : sub_edges) {
|
for (auto &iter : sub_edges) {
|
||||||
auto root_node = iter.first.first;
|
auto root_node = iter.first.first;
|
||||||
int index = iter.first.second.first;
|
int index = iter.first.second.first;
|
||||||
auto new_node = iter.first.second.second;
|
auto new_node = iter.first.second.second;
|
||||||
|
@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
|
||||||
auto sub_nodes = adds - rms;
|
auto sub_nodes = adds - rms;
|
||||||
std::vector<AnfNodePtr> nodes;
|
std::vector<AnfNodePtr> nodes;
|
||||||
(void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes),
|
(void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes),
|
||||||
[](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; });
|
[](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
|
||||||
|
|
||||||
AcquireNodes(nodes);
|
AcquireNodes(nodes);
|
||||||
|
|
||||||
auto sub_edges_reverse = rm_edges - add_edges;
|
auto sub_edges_reverse = rm_edges - add_edges;
|
||||||
for (auto& iter : sub_edges_reverse) {
|
for (auto &iter : sub_edges_reverse) {
|
||||||
auto root_node = iter.first.first;
|
auto root_node = iter.first.first;
|
||||||
int index = iter.first.second.first;
|
int index = iter.first.second.first;
|
||||||
auto old_node = iter.first.second.second;
|
auto old_node = iter.first.second.second;
|
||||||
|
@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
|
||||||
std::vector<AnfNodePtr> nodes_reverse;
|
std::vector<AnfNodePtr> nodes_reverse;
|
||||||
|
|
||||||
(void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse),
|
(void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse),
|
||||||
[](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; });
|
[](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
|
||||||
|
|
||||||
auto drop_func_graphs = MaybeDropNodes(nodes_reverse);
|
auto drop_func_graphs = MaybeDropNodes(nodes_reverse);
|
||||||
MaybeDropFuncGraphs(*drop_func_graphs);
|
MaybeDropFuncGraphs(*drop_func_graphs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params) {
|
void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms) {
|
||||||
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
|
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) {
|
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||||
MS_EXCEPTION_IF_NULL(old_node);
|
MS_EXCEPTION_IF_NULL(old_node);
|
||||||
MS_EXCEPTION_IF_NULL(new_node);
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
FuncGraphPtr old_func_graph = old_node->func_graph();
|
FuncGraphPtr old_func_graph = old_node->func_graph();
|
||||||
|
@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr&
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto users = manager_->node_users()[old_node];
|
auto users = manager_->node_users()[old_node];
|
||||||
for (auto& node : users) {
|
for (auto &node : users) {
|
||||||
SetEdge(node.first, node.second, new_node);
|
SetEdge(node.first, node.second, new_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) {
|
void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
|
||||||
if (k < 0) {
|
if (k < 0) {
|
||||||
MS_LOG(EXCEPTION) << "Invalid value k = " << k;
|
MS_LOG(EXCEPTION) << "Invalid value k = " << k;
|
||||||
}
|
}
|
||||||
|
@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() {
|
||||||
manager_->CommitChanges(changes);
|
manager_->CommitChanges(changes);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager)
|
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
|
||||||
: manager_(manager), include_func_graph_none_(false) {
|
: manager_(manager), include_func_graph_none_(false) {
|
||||||
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
|
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
|
||||||
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
|
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
|
||||||
|
@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager)
|
||||||
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
|
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() {
|
NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
|
||||||
include_func_graph_none_ = true;
|
include_func_graph_none_ = true;
|
||||||
nodes_analysis_[nullptr] = AnfNodeSet();
|
nodes_analysis_[nullptr] = AnfNodeSet();
|
||||||
|
|
||||||
|
@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) {
|
||||||
|
|
||||||
void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
// change the owner of node except for the src's return node
|
// change the owner of node except for the src's return node
|
||||||
for (auto& it : nodes_analysis_[src]) {
|
for (auto &it : nodes_analysis_[src]) {
|
||||||
nodes_analysis_[dst].add(it);
|
nodes_analysis_[dst].add(it);
|
||||||
}
|
}
|
||||||
(void)nodes_analysis_.erase(src);
|
(void)nodes_analysis_.erase(src);
|
||||||
|
@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
|
|
||||||
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
|
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
|
||||||
|
|
||||||
DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) {
|
DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
|
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
|
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
|
||||||
|
|
||||||
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) {
|
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
|
||||||
auto& d = count_nodes_map_[func_graph];
|
auto &d = count_nodes_map_[func_graph];
|
||||||
if (d.count(key) == 0) {
|
if (d.count(key) == 0) {
|
||||||
d[key] = count;
|
d[key] = count;
|
||||||
return true;
|
return true;
|
||||||
|
@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) {
|
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto& d = count_nodes_map_[func_graph];
|
auto &d = count_nodes_map_[func_graph];
|
||||||
if (d.count(key) != 0) {
|
if (d.count(key) != 0) {
|
||||||
if (d[key] == count) {
|
if (d[key] == count) {
|
||||||
(void)d.erase(key);
|
(void)d.erase(key);
|
||||||
|
@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) {
|
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) {
|
||||||
if (count > 0) {
|
if (count > 0) {
|
||||||
return Inc(func_graph, key, count);
|
return Inc(func_graph, key, count);
|
||||||
} else if (count < 0) {
|
} else if (count < 0) {
|
||||||
|
@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) {
|
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
|
||||||
auto& d = count_func_graphs_map_[func_graph];
|
auto &d = count_func_graphs_map_[func_graph];
|
||||||
if (d.count(key) == 0) {
|
if (d.count(key) == 0) {
|
||||||
d[key] = count;
|
d[key] = count;
|
||||||
return true;
|
return true;
|
||||||
|
@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) {
|
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
|
||||||
auto& d = count_func_graphs_map_[func_graph];
|
auto &d = count_func_graphs_map_[func_graph];
|
||||||
if (d.count(key) != 0) {
|
if (d.count(key) != 0) {
|
||||||
if (d[key] == count) {
|
if (d[key] == count) {
|
||||||
(void)d.erase(key);
|
(void)d.erase(key);
|
||||||
|
@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) {
|
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
|
||||||
if (count > 0) {
|
if (count > 0) {
|
||||||
return Inc(func_graph, key, count);
|
return Inc(func_graph, key, count);
|
||||||
} else if (count < 0) {
|
} else if (count < 0) {
|
||||||
|
@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr
|
||||||
}
|
}
|
||||||
|
|
||||||
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
for (auto& it : count_nodes_map_[src]) {
|
for (auto &it : count_nodes_map_[src]) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
}
|
}
|
||||||
(void)count_nodes_map_.erase(src);
|
(void)count_nodes_map_.erase(src);
|
||||||
|
@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
for (auto& it : count_nodes_map_[src]) {
|
for (auto &it : count_nodes_map_[src]) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
}
|
}
|
||||||
(void)count_nodes_map_.erase(src);
|
(void)count_nodes_map_.erase(src);
|
||||||
|
@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc
|
||||||
}
|
}
|
||||||
|
|
||||||
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
for (auto& it : count_nodes_map_[src]) {
|
for (auto &it : count_nodes_map_[src]) {
|
||||||
FuncGraphPtr fg2 = it.first->func_graph();
|
FuncGraphPtr fg2 = it.first->func_graph();
|
||||||
if (fg2 != dst) {
|
if (fg2 != dst) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
|
@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
(void)count_nodes_map_.erase(src);
|
(void)count_nodes_map_.erase(src);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) {
|
static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
|
||||||
FuncGraphPtr gn = std::make_shared<FuncGraph>();
|
FuncGraphPtr gn = std::make_shared<FuncGraph>();
|
||||||
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
|
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
|
||||||
return gn;
|
return gn;
|
||||||
|
@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
for (auto& it : count_func_graphs_map_[src]) {
|
for (auto &it : count_func_graphs_map_[src]) {
|
||||||
FuncGraphPtr fg = it.first;
|
FuncGraphPtr fg = it.first;
|
||||||
if (fg != dst) {
|
if (fg != dst) {
|
||||||
(void)Inc(dst, fg, it.second);
|
(void)Inc(dst, fg, it.second);
|
||||||
|
@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
for (auto& it : count_func_graphs_map_[src]) {
|
for (auto &it : count_func_graphs_map_[src]) {
|
||||||
if (it.first != dst) {
|
if (it.first != dst) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
}
|
}
|
||||||
|
@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed
|
||||||
|
|
||||||
void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
// all graph use in src need to change to dst, so meger the to dst use
|
// all graph use in src need to change to dst, so meger the to dst use
|
||||||
for (auto& it : count_func_graphs_map_[src]) {
|
for (auto &it : count_func_graphs_map_[src]) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
}
|
}
|
||||||
(void)count_func_graphs_map_[dst].erase(src);
|
(void)count_func_graphs_map_[dst].erase(src);
|
||||||
|
@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
for (auto& it : count_nodes_map_[src]) {
|
for (auto &it : count_nodes_map_[src]) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
}
|
}
|
||||||
(void)count_nodes_map_.erase(src);
|
(void)count_nodes_map_.erase(src);
|
||||||
|
@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp,
|
||||||
|
|
||||||
void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||||
// all graph use in src need to change to dst, so meger the to dst use
|
// all graph use in src need to change to dst, so meger the to dst use
|
||||||
for (auto& it : count_func_graphs_map_[src]) {
|
for (auto &it : count_func_graphs_map_[src]) {
|
||||||
(void)Inc(dst, it.first, it.second);
|
(void)Inc(dst, it.first, it.second);
|
||||||
}
|
}
|
||||||
(void)count_func_graphs_map_.erase(src);
|
(void)count_func_graphs_map_.erase(src);
|
||||||
}
|
}
|
||||||
|
|
||||||
DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) {
|
DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
|
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
|
||||||
validate_ = false;
|
validate_ = false;
|
||||||
|
@ -914,20 +914,20 @@ void DepComputer::Recompute() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DepComputer::Recompute(const FuncGraphPtr& fg) {
|
void DepComputer::Recompute(const FuncGraphPtr &fg) {
|
||||||
if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
|
if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
|
||||||
RealRecompute(fg);
|
RealRecompute(fg);
|
||||||
func_graphs_validate_[fg] = true;
|
func_graphs_validate_[fg] = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
|
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
|
||||||
if (path == nullptr || path->contains(fg)) {
|
if (path == nullptr || path->contains(fg)) {
|
||||||
return std::make_shared<FuncGraphSet>();
|
return std::make_shared<FuncGraphSet>();
|
||||||
}
|
}
|
||||||
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
|
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
|
||||||
FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_;
|
FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_;
|
||||||
for (auto& dep : deps[fg]) {
|
for (auto &dep : deps[fg]) {
|
||||||
MS_EXCEPTION_IF_NULL(dep.first);
|
MS_EXCEPTION_IF_NULL(dep.first);
|
||||||
auto proxy = dep.first->transforms().find("proxy");
|
auto proxy = dep.first->transforms().find("proxy");
|
||||||
if (proxy != dep.first->transforms().end()) {
|
if (proxy != dep.first->transforms().end()) {
|
||||||
|
@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
|
MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) {
|
bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
|
||||||
auto l1 = lhs.second.size();
|
auto l1 = lhs.second.size();
|
||||||
auto l2 = rhs.second.size();
|
auto l2 = rhs.second.size();
|
||||||
return l1 < l2;
|
return l1 < l2;
|
||||||
|
@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
} else {
|
} else {
|
||||||
// return nearest parent as parent
|
// return nearest parent as parent
|
||||||
FuncGraphSet deps_copy(deps);
|
FuncGraphSet deps_copy(deps);
|
||||||
for (auto& dep : deps) {
|
for (auto &dep : deps) {
|
||||||
auto parent_deps = this->manager_->func_graph_parents_total(dep);
|
auto parent_deps = this->manager_->func_graph_parents_total(dep);
|
||||||
for (auto& p_d : parent_deps) {
|
for (auto &p_d : parent_deps) {
|
||||||
if (deps_copy.count(p_d)) {
|
if (deps_copy.count(p_d)) {
|
||||||
(void)deps_copy.erase(p_d);
|
(void)deps_copy.erase(p_d);
|
||||||
}
|
}
|
||||||
|
@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
|
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
auto used_fg_total = manager_->func_graphs_used_total(fg);
|
auto used_fg_total = manager_->func_graphs_used_total(fg);
|
||||||
for (auto& used_fg : used_fg_total) {
|
for (auto &used_fg : used_fg_total) {
|
||||||
if (manager_->parent(used_fg) == fg) {
|
if (manager_->parent(used_fg) == fg) {
|
||||||
children_analysis_[fg].add(used_fg);
|
children_analysis_[fg].add(used_fg);
|
||||||
}
|
}
|
||||||
|
@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
|
|
||||||
void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
|
void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
auto& children = manager_->children(fg);
|
auto &children = manager_->children(fg);
|
||||||
|
|
||||||
scope_analysis_[fg] = FuncGraphSet();
|
scope_analysis_[fg] = FuncGraphSet();
|
||||||
scope_analysis_[fg].add(fg);
|
scope_analysis_[fg].add(fg);
|
||||||
for (auto& child : children) {
|
for (auto &child : children) {
|
||||||
scope_analysis_[fg].add(child);
|
scope_analysis_[fg].add(child);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() {
|
||||||
auto manager = DepComputer::manager_;
|
auto manager = DepComputer::manager_;
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
|
||||||
for (auto& fg : manager->func_graphs()) {
|
for (auto &fg : manager->func_graphs()) {
|
||||||
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
|
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
|
||||||
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
|
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
|
||||||
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
|
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& fg : manager->func_graphs()) {
|
for (auto &fg : manager->func_graphs()) {
|
||||||
AnfNodeCounterMap items = manager->free_variables_direct()[fg];
|
AnfNodeCounterMap items = manager->free_variables_direct()[fg];
|
||||||
for (auto& iter : items) {
|
for (auto &iter : items) {
|
||||||
auto curr = fg;
|
auto curr = fg;
|
||||||
while (curr) {
|
while (curr) {
|
||||||
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
|
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
|
||||||
curr = manager->parent(curr);
|
curr = manager->parent(curr);
|
||||||
const AnfNodeSet& nodes = manager->nodes()[curr];
|
const AnfNodeSet &nodes = manager->nodes()[curr];
|
||||||
if (nodes.contains(iter.first)) {
|
if (nodes.contains(iter.first)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto items_fg = manager->func_graphs_used()[fg];
|
auto items_fg = manager->func_graphs_used()[fg];
|
||||||
for (auto& iter : items_fg) {
|
for (auto &iter : items_fg) {
|
||||||
auto p = manager->parent(iter.first);
|
auto p = manager->parent(iter.first);
|
||||||
if (p == nullptr) {
|
if (p == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& fg : manager->func_graphs()) {
|
for (auto &fg : manager->func_graphs()) {
|
||||||
auto& fvp = count_nodes_map_[fg];
|
auto &fvp = count_nodes_map_[fg];
|
||||||
auto& fvg = count_func_graphs_map_[fg];
|
auto &fvg = count_func_graphs_map_[fg];
|
||||||
for (auto& item : fvp) {
|
for (auto &item : fvp) {
|
||||||
fv_total_analysis_[fg][item.first] = item.second;
|
fv_total_analysis_[fg][item.first] = item.second;
|
||||||
}
|
}
|
||||||
for (auto& item : fvg) {
|
for (auto &item : fvg) {
|
||||||
fv_total_analysis_[fg][item.first] = item.second;
|
fv_total_analysis_[fg][item.first] = item.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() {
|
||||||
|
|
||||||
void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
auto& used = this->manager_->func_graphs_used();
|
auto &used = this->manager_->func_graphs_used();
|
||||||
std::vector<FuncGraphPtr> todo;
|
std::vector<FuncGraphPtr> todo;
|
||||||
std::vector<FuncGraphPtr> todo_new;
|
std::vector<FuncGraphPtr> todo_new;
|
||||||
|
|
||||||
todo.push_back(fg);
|
todo.push_back(fg);
|
||||||
while (!todo.empty()) {
|
while (!todo.empty()) {
|
||||||
todo_new.clear();
|
todo_new.clear();
|
||||||
for (auto& gt : todo) {
|
for (auto > : todo) {
|
||||||
for (auto& item : used[gt]) {
|
for (auto &item : used[gt]) {
|
||||||
auto used_fg = item.first;
|
auto used_fg = item.first;
|
||||||
if (used_fg == fg) {
|
if (used_fg == fg) {
|
||||||
func_graph_used_total_analysis_[fg].add(used_fg);
|
func_graph_used_total_analysis_[fg].add(used_fg);
|
||||||
|
@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) {
|
bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
auto& used = manager->func_graphs_used();
|
auto &used = manager->func_graphs_used();
|
||||||
std::vector<FuncGraphPtr> todo;
|
std::vector<FuncGraphPtr> todo;
|
||||||
std::vector<FuncGraphPtr> todo_new;
|
std::vector<FuncGraphPtr> todo_new;
|
||||||
todo.push_back(fg);
|
todo.push_back(fg);
|
||||||
FuncGraphSet used_total;
|
FuncGraphSet used_total;
|
||||||
while (!todo.empty()) {
|
while (!todo.empty()) {
|
||||||
todo_new.clear();
|
todo_new.clear();
|
||||||
for (auto& gt : todo) {
|
for (auto > : todo) {
|
||||||
for (auto& item : used[gt]) {
|
for (auto &item : used[gt]) {
|
||||||
auto used_g = item.first;
|
auto used_g = item.first;
|
||||||
if (used_g == fg) {
|
if (used_g == fg) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
|
this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace) {
|
void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
|
||||||
MS_EXCEPTION_IF_NULL(trace);
|
MS_EXCEPTION_IF_NULL(trace);
|
||||||
auto res = std::find(trace->begin(), trace->end(), fg);
|
auto res = std::find(trace->begin(), trace->end(), fg);
|
||||||
// find recursive
|
// find recursive
|
||||||
|
@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
trace->push_back(fg);
|
trace->push_back(fg);
|
||||||
auto& used_fgs = manager_->func_graphs_used()[fg];
|
auto &used_fgs = manager_->func_graphs_used()[fg];
|
||||||
for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
|
for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
|
||||||
CheckRecursiveGraphs(iter->first, trace);
|
CheckRecursiveGraphs(iter->first, trace);
|
||||||
}
|
}
|
||||||
|
@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
|
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
|
||||||
MS_EXCEPTION_IF_NULL(path);
|
MS_EXCEPTION_IF_NULL(path);
|
||||||
if (path->contains(fg)) {
|
if (path->contains(fg)) {
|
||||||
MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked";
|
MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
auto& func_graph_counter_map = manager_->func_graph_j_direct();
|
auto &func_graph_counter_map = manager_->func_graph_j_direct();
|
||||||
if (!func_graph_counter_map[fg].empty()) {
|
if (!func_graph_counter_map[fg].empty()) {
|
||||||
// check g1->J(fg)->g2->g cycle;
|
// check g1->J(fg)->g2->g cycle;
|
||||||
auto contains_j =
|
auto contains_j =
|
||||||
|
@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt
|
||||||
path->add(fg);
|
path->add(fg);
|
||||||
|
|
||||||
// check if func graphs used contains J(func_graph);
|
// check if func graphs used contains J(func_graph);
|
||||||
auto& used = this->manager_->func_graphs_used();
|
auto &used = this->manager_->func_graphs_used();
|
||||||
for (auto& item : used[fg]) {
|
for (auto &item : used[fg]) {
|
||||||
auto used_g = item.first;
|
auto used_g = item.first;
|
||||||
if (SeekJ(used_g, path)) {
|
if (SeekJ(used_g, path)) {
|
||||||
MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString()
|
MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString()
|
||||||
|
|
|
@ -46,13 +46,13 @@ class FuncGraphManager;
|
||||||
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
|
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
|
||||||
|
|
||||||
struct AnfNodeIndexPairHasher {
|
struct AnfNodeIndexPairHasher {
|
||||||
std::size_t operator()(const std::pair<AnfNodePtr, int>& p1) const {
|
std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
|
||||||
return std::hash<const AnfNode*>{}(p1.first.get());
|
return std::hash<const AnfNode *>{}(p1.first.get());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AnfNodeIndexPairEqual {
|
struct AnfNodeIndexPairEqual {
|
||||||
bool operator()(const std::pair<AnfNodePtr, int>& lhs, const std::pair<AnfNodePtr, int>& rhs) const {
|
bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
|
||||||
return lhs == rhs;
|
return lhs == rhs;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
|
||||||
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
|
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
|
||||||
using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>;
|
using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>;
|
||||||
struct EdgeTupleHasher {
|
struct EdgeTupleHasher {
|
||||||
std::size_t operator()(const EdgeTuple& p1) const {
|
std::size_t operator()(const EdgeTuple &p1) const {
|
||||||
return hash_combine({std::hash<AnfNode*>{}(p1.first.get()), std::hash<int>{}(p1.second.first),
|
return hash_combine({std::hash<AnfNode *>{}(p1.first.get()), std::hash<int>{}(p1.second.first),
|
||||||
std::hash<AnfNode*>{}(p1.second.second.get())});
|
std::hash<AnfNode *>{}(p1.second.second.get())});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct EdgeTupleEqual {
|
struct EdgeTupleEqual {
|
||||||
bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const {
|
bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const {
|
||||||
return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second;
|
return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter<EdgeTuple, EdgeTupleHasher, EdgeTupleEqual>;
|
||||||
// FuncGraphManagerPtr: return created manager
|
// FuncGraphManagerPtr: return created manager
|
||||||
FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true);
|
FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true);
|
||||||
|
|
||||||
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage = true);
|
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true);
|
||||||
|
|
||||||
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs = {}, bool manage = true);
|
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
|
||||||
|
|
||||||
struct Signals {
|
struct Signals {
|
||||||
Signal<void(FuncGraphPtr)> AddFuncGraph;
|
Signal<void(FuncGraphPtr)> AddFuncGraph;
|
||||||
|
@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNode
|
||||||
// analysis base class
|
// analysis base class
|
||||||
class FuncGraphAnalysis {
|
class FuncGraphAnalysis {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphAnalysis(const FuncGraphManager* const manager);
|
explicit FuncGraphAnalysis(const FuncGraphManager *const manager);
|
||||||
|
|
||||||
virtual ~FuncGraphAnalysis() { manager_ = nullptr; }
|
virtual ~FuncGraphAnalysis() { manager_ = nullptr; }
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ class FuncGraphAnalysis {
|
||||||
|
|
||||||
virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}
|
virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}
|
||||||
|
|
||||||
const FuncGraphManager* manager_;
|
const FuncGraphManager *manager_;
|
||||||
bool include_func_graph_none_;
|
bool include_func_graph_none_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
|
||||||
// graphs analysis which compute in write, read needn't recompute
|
// graphs analysis which compute in write, read needn't recompute
|
||||||
class DepCollector : public FuncGraphAnalysis {
|
class DepCollector : public FuncGraphAnalysis {
|
||||||
public:
|
public:
|
||||||
explicit DepCollector(const FuncGraphManager* manager);
|
explicit DepCollector(const FuncGraphManager *manager);
|
||||||
~DepCollector() override = default;
|
~DepCollector() override = default;
|
||||||
|
|
||||||
void Reset() { ExtraReset(); }
|
void Reset() { ExtraReset(); }
|
||||||
|
@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis {
|
||||||
|
|
||||||
class NodesCollector final : public DepCollector {
|
class NodesCollector final : public DepCollector {
|
||||||
public:
|
public:
|
||||||
explicit NodesCollector(const FuncGraphManager* m);
|
explicit NodesCollector(const FuncGraphManager *m);
|
||||||
~NodesCollector() override = default;
|
~NodesCollector() override = default;
|
||||||
|
|
||||||
const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; }
|
const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
|
||||||
size_t size() const override { return nodes_analysis_.size(); }
|
size_t size() const override { return nodes_analysis_.size(); }
|
||||||
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
|
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
|
||||||
|
|
||||||
|
@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector {
|
||||||
|
|
||||||
class CounterFuncGraphCollector : public DepCollector {
|
class CounterFuncGraphCollector : public DepCollector {
|
||||||
public:
|
public:
|
||||||
explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {}
|
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||||
~CounterFuncGraphCollector() override = default;
|
~CounterFuncGraphCollector() override = default;
|
||||||
FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; }
|
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
|
||||||
// inherit from FuncGraphAnalysis
|
// inherit from FuncGraphAnalysis
|
||||||
size_t size() const override { return count_func_graphs_map_.size(); }
|
size_t size() const override { return count_func_graphs_map_.size(); }
|
||||||
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
|
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
|
||||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
|
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
|
||||||
bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count);
|
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
|
||||||
bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count);
|
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
|
||||||
bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count);
|
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
|
||||||
|
|
||||||
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
|
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
|
||||||
|
|
||||||
|
@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector {
|
||||||
|
|
||||||
class CounterAnfNodeCollector : public DepCollector {
|
class CounterAnfNodeCollector : public DepCollector {
|
||||||
public:
|
public:
|
||||||
explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {}
|
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||||
~CounterAnfNodeCollector() override = default;
|
~CounterAnfNodeCollector() override = default;
|
||||||
FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; }
|
FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; }
|
||||||
|
|
||||||
size_t size() const override { return count_nodes_map_.size(); }
|
size_t size() const override { return count_nodes_map_.size(); }
|
||||||
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); }
|
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); }
|
||||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
||||||
|
|
||||||
bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count);
|
bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
|
||||||
bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count);
|
bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
|
||||||
bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count);
|
bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
|
||||||
|
|
||||||
FuncGraphToAnfNodeCounterMap count_nodes_map_;
|
FuncGraphToAnfNodeCounterMap count_nodes_map_;
|
||||||
|
|
||||||
|
@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector {
|
||||||
|
|
||||||
class ValueNodesCollector final : public CounterAnfNodeCollector {
|
class ValueNodesCollector final : public CounterAnfNodeCollector {
|
||||||
public:
|
public:
|
||||||
explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
|
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||||
~ValueNodesCollector() override = default;
|
~ValueNodesCollector() override = default;
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector {
|
||||||
|
|
||||||
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
|
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
|
explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||||
~FuncGraphValueNodesCollector() override = default;
|
~FuncGraphValueNodesCollector() override = default;
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
|
|
||||||
|
@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
|
||||||
|
|
||||||
class FVDirectCollector final : public CounterAnfNodeCollector {
|
class FVDirectCollector final : public CounterAnfNodeCollector {
|
||||||
public:
|
public:
|
||||||
explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
|
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||||
~FVDirectCollector() override = default;
|
~FVDirectCollector() override = default;
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
|
|
||||||
|
@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector {
|
||||||
|
|
||||||
class FuncGraphChildDirect final : public CounterFuncGraphCollector {
|
class FuncGraphChildDirect final : public CounterFuncGraphCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
|
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
|
|
||||||
~FuncGraphChildDirect() override = default;
|
~FuncGraphChildDirect() override = default;
|
||||||
|
@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector {
|
||||||
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
|
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
|
||||||
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
|
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
|
explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||||
~FuncGraphParentsDirectCollector() override = default;
|
~FuncGraphParentsDirectCollector() override = default;
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
|
|
||||||
|
@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
|
||||||
// graph's all used graphs: key is g, value is g used graph
|
// graph's all used graphs: key is g, value is g used graph
|
||||||
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
|
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
|
explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
~FuncGraphsUsedCollector() override = default;
|
~FuncGraphsUsedCollector() override = default;
|
||||||
|
|
||||||
|
@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
|
||||||
// graph's all user graphs: key is g, value is graphs who used g
|
// graph's all user graphs: key is g, value is graphs who used g
|
||||||
class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
|
class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
|
explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
~FuncGraphUsersCollector() override = default;
|
~FuncGraphUsersCollector() override = default;
|
||||||
|
|
||||||
|
@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
|
||||||
// graph's all user cnodes: key is g, value is cnodes who used g
|
// graph's all user cnodes: key is g, value is cnodes who used g
|
||||||
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
|
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
|
explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||||
~FuncGraphUserNodesCollector() override = default;
|
~FuncGraphUserNodesCollector() override = default;
|
||||||
|
|
||||||
|
@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
|
||||||
|
|
||||||
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
|
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
|
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||||
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
|
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
|
||||||
~FuncGraphJDirectCollector() override = default;
|
~FuncGraphJDirectCollector() override = default;
|
||||||
|
|
||||||
|
@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
|
||||||
// graphs analysis which need dynamic compute by DepCollector in each read
|
// graphs analysis which need dynamic compute by DepCollector in each read
|
||||||
class DepComputer : public FuncGraphAnalysis {
|
class DepComputer : public FuncGraphAnalysis {
|
||||||
public:
|
public:
|
||||||
explicit DepComputer(const FuncGraphManager* manager);
|
explicit DepComputer(const FuncGraphManager *manager);
|
||||||
~DepComputer() override = default;
|
~DepComputer() override = default;
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
|
@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis {
|
||||||
|
|
||||||
void Recompute();
|
void Recompute();
|
||||||
|
|
||||||
void Recompute(const FuncGraphPtr& fg);
|
void Recompute(const FuncGraphPtr &fg);
|
||||||
|
|
||||||
bool IsValidate() const { return validate_; }
|
bool IsValidate() const { return validate_; }
|
||||||
|
|
||||||
bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; }
|
bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
|
||||||
|
|
||||||
void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }
|
void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }
|
||||||
|
|
||||||
|
@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis {
|
||||||
// graph g's all direct or proxy parents
|
// graph g's all direct or proxy parents
|
||||||
class FuncGraphParentsTotalComputer final : public DepComputer {
|
class FuncGraphParentsTotalComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {}
|
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {}
|
||||||
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; }
|
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; }
|
||||||
|
|
||||||
FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
|
FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return func_graph_parents_total_analysis_.size(); }
|
size_t size() const override { return func_graph_parents_total_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
|
||||||
void RealRecompute(FuncGraphPtr fg) override;
|
void RealRecompute(FuncGraphPtr fg) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>());
|
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
|
||||||
// when SeekParents calls itself recursively, it can access these variables by class member
|
// when SeekParents calls itself recursively, it can access these variables by class member
|
||||||
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
|
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
|
||||||
FuncGraphToFuncGraphCounterMap* all_parents_direct_;
|
FuncGraphToFuncGraphCounterMap *all_parents_direct_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
|
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
|
||||||
|
@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
|
||||||
// graph's nearest parent in parents total
|
// graph's nearest parent in parents total
|
||||||
class ParentComputer final : public DepComputer {
|
class ParentComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||||
~ParentComputer() override = default;
|
~ParentComputer() override = default;
|
||||||
|
|
||||||
FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; }
|
FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return parent_analysis_.size(); }
|
size_t size() const override { return parent_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer {
|
||||||
// graph's children graph except self
|
// graph's children graph except self
|
||||||
class ChildrenComputer final : public DepComputer {
|
class ChildrenComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||||
~ChildrenComputer() override = default;
|
~ChildrenComputer() override = default;
|
||||||
|
|
||||||
FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; }
|
FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return children_analysis_.size(); }
|
size_t size() const override { return children_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer {
|
||||||
// graph's children graph include self
|
// graph's children graph include self
|
||||||
class ScopeComputer final : public DepComputer {
|
class ScopeComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||||
~ScopeComputer() override = default;
|
~ScopeComputer() override = default;
|
||||||
|
|
||||||
FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; }
|
FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return scope_analysis_.size(); }
|
size_t size() const override { return scope_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash
|
||||||
|
|
||||||
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector {
|
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector {
|
||||||
public:
|
public:
|
||||||
explicit FVTotalComputer(const FuncGraphManager* m)
|
explicit FVTotalComputer(const FuncGraphManager *m)
|
||||||
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
|
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
|
||||||
~FVTotalComputer() override = default;
|
~FVTotalComputer() override = default;
|
||||||
|
|
||||||
FVTotalMap& fv_total_analysis() { return fv_total_analysis_; }
|
FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return fv_total_analysis_.size(); }
|
size_t size() const override { return fv_total_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -453,10 +453,10 @@ class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector
|
||||||
|
|
||||||
class FuncGraphsUsedTotalComputer final : public DepComputer {
|
class FuncGraphsUsedTotalComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphsUsedTotalComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||||
~FuncGraphsUsedTotalComputer() override = default;
|
~FuncGraphsUsedTotalComputer() override = default;
|
||||||
|
|
||||||
FuncGraphToFuncGraphSetMap& func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
|
FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return func_graph_used_total_analysis_.size(); }
|
size_t size() const override { return func_graph_used_total_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -473,13 +473,13 @@ using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGrap
|
||||||
|
|
||||||
class RecursiveComputer final : public DepComputer {
|
class RecursiveComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit RecursiveComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||||
~RecursiveComputer() override = default;
|
~RecursiveComputer() override = default;
|
||||||
|
|
||||||
RecursiveMap& recursive_map() { return recursive_map_; }
|
RecursiveMap &recursive_map() { return recursive_map_; }
|
||||||
FuncGraphToBoolMap& recursive_analysis() { return recursive_analysis_; }
|
FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
|
||||||
|
|
||||||
void CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace);
|
void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
|
||||||
|
|
||||||
size_t size() const override { return recursive_analysis_.size(); }
|
size_t size() const override { return recursive_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer {
|
||||||
|
|
||||||
class FuncGraphJTotalComputer final : public DepComputer {
|
class FuncGraphJTotalComputer final : public DepComputer {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||||
~FuncGraphJTotalComputer() override = default;
|
~FuncGraphJTotalComputer() override = default;
|
||||||
|
|
||||||
FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; }
|
FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; }
|
||||||
|
|
||||||
size_t size() const override { return j_total_analysis_.size(); }
|
size_t size() const override { return j_total_analysis_.size(); }
|
||||||
|
|
||||||
|
@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer {
|
||||||
void ExtraReset() override { j_total_analysis_.clear(); }
|
void ExtraReset() override { j_total_analysis_.clear(); }
|
||||||
|
|
||||||
void RealRecompute(FuncGraphPtr fg) override;
|
void RealRecompute(FuncGraphPtr fg) override;
|
||||||
bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path);
|
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path);
|
||||||
};
|
};
|
||||||
|
|
||||||
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage = true);
|
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
|
||||||
~FuncGraphManager() {
|
~FuncGraphManager() {
|
||||||
if (is_manage_) {
|
if (is_manage_) {
|
||||||
RemoveRoots();
|
RemoveRoots();
|
||||||
|
@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
||||||
void Init();
|
void Init();
|
||||||
void Clear();
|
void Clear();
|
||||||
void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false);
|
void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false);
|
||||||
void KeepRoots(const std::vector<FuncGraphPtr>& roots = {});
|
void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
|
||||||
void RemoveRoots();
|
void RemoveRoots();
|
||||||
void SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters);
|
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters);
|
||||||
void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false);
|
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
|
||||||
bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node);
|
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||||
void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value);
|
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
|
||||||
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope);
|
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope);
|
||||||
|
|
||||||
FuncGraphTransaction Transact();
|
FuncGraphTransaction Transact();
|
||||||
void CommitChanges(const std::vector<Change>& changes);
|
void CommitChanges(const std::vector<Change> &changes);
|
||||||
|
|
||||||
bool IsManaged() const { return is_manage_; }
|
bool IsManaged() const { return is_manage_; }
|
||||||
|
|
||||||
const FuncGraphSet& roots() const { return roots_; }
|
const FuncGraphSet &roots() const { return roots_; }
|
||||||
|
|
||||||
const FuncGraphSet& func_graphs() const { return func_graphs_; }
|
const FuncGraphSet &func_graphs() const { return func_graphs_; }
|
||||||
|
|
||||||
AnfNodeSet& all_nodes() { return all_nodes_; }
|
AnfNodeSet &all_nodes() { return all_nodes_; }
|
||||||
|
|
||||||
NodeUsersMap& node_users() { return node_users_; }
|
NodeUsersMap &node_users() { return node_users_; }
|
||||||
|
|
||||||
FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; }
|
FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
|
||||||
|
|
||||||
FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; }
|
FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; }
|
||||||
|
|
||||||
FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
|
FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
|
||||||
|
|
||||||
FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
|
FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
|
||||||
|
|
||||||
FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
|
FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
|
||||||
|
|
||||||
FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
|
FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
|
||||||
|
|
||||||
FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
|
FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
|
||||||
|
|
||||||
FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const {
|
FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
|
||||||
return func_graph_child_direct_->count_func_graphs_map_;
|
return func_graph_child_direct_->count_func_graphs_map_;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const {
|
FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const {
|
||||||
return func_graph_parents_direct_->count_func_graphs_map_;
|
return func_graph_parents_direct_->count_func_graphs_map_;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }
|
FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }
|
||||||
|
|
||||||
FVTotalMap& free_variables_total() const;
|
FVTotalMap &free_variables_total() const;
|
||||||
|
|
||||||
FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const;
|
FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
FuncGraphSet& scopes(const FuncGraphPtr& fg) const;
|
FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
FuncGraphPtr parent(const FuncGraphPtr& fg) const;
|
FuncGraphPtr parent(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
FuncGraphSet& children(const FuncGraphPtr& fg) const;
|
FuncGraphSet &children(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const;
|
FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
bool recursive(const FuncGraphPtr& fg) const;
|
bool recursive(const FuncGraphPtr &fg) const;
|
||||||
std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr& fg) const;
|
std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
bool func_graph_j_total(const FuncGraphPtr& fg) const;
|
bool func_graph_j_total(const FuncGraphPtr &fg) const;
|
||||||
|
|
||||||
std::shared_ptr<Signals> signals() const { return signals_; }
|
std::shared_ptr<Signals> signals() const { return signals_; }
|
||||||
|
|
||||||
IncludeType Limit(const AnfNodePtr& node);
|
IncludeType Limit(const AnfNodePtr &node);
|
||||||
|
|
||||||
// Static Analysis
|
// Static Analysis
|
||||||
NodeUsersMap node_users_;
|
NodeUsersMap node_users_;
|
||||||
|
@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
||||||
std::shared_ptr<ParentComputer> func_graph_parent_;
|
std::shared_ptr<ParentComputer> func_graph_parent_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void AddIntoManaged(const FuncGraphPtr& fg);
|
void AddIntoManaged(const FuncGraphPtr &fg);
|
||||||
void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction);
|
void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction);
|
||||||
void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction);
|
void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction);
|
||||||
void AcquireNodes(const std::vector<AnfNodePtr>& nodes);
|
void AcquireNodes(const std::vector<AnfNodePtr> &nodes);
|
||||||
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr>& nodes);
|
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes);
|
||||||
void ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges,
|
void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges,
|
||||||
Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms);
|
Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms);
|
||||||
|
|
||||||
FuncGraphSet roots_; // managed roots
|
FuncGraphSet roots_; // managed roots
|
||||||
FuncGraphSet func_graphs_; // managed func graphs
|
FuncGraphSet func_graphs_; // managed func graphs
|
||||||
|
@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
||||||
|
|
||||||
class FuncGraphTransaction {
|
class FuncGraphTransaction {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() {
|
explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() {
|
||||||
MS_EXCEPTION_IF_NULL(manager_);
|
MS_EXCEPTION_IF_NULL(manager_);
|
||||||
if (!manager_->IsManaged()) {
|
if (!manager_->IsManaged()) {
|
||||||
MS_LOG(DEBUG) << "The manager is not managed yet";
|
MS_LOG(DEBUG) << "The manager is not managed yet";
|
||||||
|
@ -648,19 +648,19 @@ class FuncGraphTransaction {
|
||||||
~FuncGraphTransaction() { manager_ = nullptr; }
|
~FuncGraphTransaction() { manager_ = nullptr; }
|
||||||
|
|
||||||
// set parameters of a func graph
|
// set parameters of a func graph
|
||||||
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params);
|
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms);
|
||||||
|
|
||||||
// replace old_node with new_node
|
// replace old_node with new_node
|
||||||
bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node);
|
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||||
|
|
||||||
// set esge, i.e., declare setting node.inputs[key] to value.
|
// set esge, i.e., declare setting node.inputs[key] to value.
|
||||||
void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v);
|
void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
|
||||||
|
|
||||||
// commit all changes
|
// commit all changes
|
||||||
void Commit();
|
void Commit();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncGraphManager* manager_;
|
FuncGraphManager *manager_;
|
||||||
std::vector<Change> changes_;
|
std::vector<Change> changes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -668,9 +668,9 @@ class FuncGraphTransaction {
|
||||||
struct ArgsOfSetParams {
|
struct ArgsOfSetParams {
|
||||||
FuncGraphPtr func_graph;
|
FuncGraphPtr func_graph;
|
||||||
std::vector<AnfNodePtr> params;
|
std::vector<AnfNodePtr> params;
|
||||||
bool operator==(const ArgsOfSetParams& other) const { return &other == this; }
|
bool operator==(const ArgsOfSetParams &other) const { return &other == this; }
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) {
|
friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) {
|
||||||
os << "[ArgsOfSetParams]";
|
os << "[ArgsOfSetParams]";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -681,9 +681,9 @@ struct ArgsOfSetEdge {
|
||||||
CNodePtr root_node;
|
CNodePtr root_node;
|
||||||
AnfNodePtr new_node;
|
AnfNodePtr new_node;
|
||||||
size_t index;
|
size_t index;
|
||||||
bool operator==(const ArgsOfSetEdge& other) const { return &other == this; }
|
bool operator==(const ArgsOfSetEdge &other) const { return &other == this; }
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) {
|
friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) {
|
||||||
os << "[ArgsOfSetEdge]";
|
os << "[ArgsOfSetEdge]";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -693,7 +693,7 @@ struct Change {
|
||||||
enum OpName { kTxSetParams, kTxSetEdge };
|
enum OpName { kTxSetParams, kTxSetEdge };
|
||||||
OpName op;
|
OpName op;
|
||||||
Any args;
|
Any args;
|
||||||
Change(OpName name, const Any& para) : op(name), args(para) {}
|
Change(OpName name, const Any ¶) : op(name), args(para) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -42,25 +42,25 @@ namespace mindspore {
|
||||||
// generate a graph corresponding to these types.
|
// generate a graph corresponding to these types.
|
||||||
class MetaFuncGraph : public FuncGraphBase {
|
class MetaFuncGraph : public FuncGraphBase {
|
||||||
public:
|
public:
|
||||||
explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); }
|
explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); }
|
||||||
|
|
||||||
~MetaFuncGraph() override = default;
|
~MetaFuncGraph() override = default;
|
||||||
|
|
||||||
MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
|
MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
|
||||||
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node);
|
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node);
|
||||||
// Return normalized versions of the arguments.
|
// Return normalized versions of the arguments.
|
||||||
// By default, this returns args unchanged.
|
// By default, this returns args unchanged.
|
||||||
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const {
|
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
|
||||||
return args_spec_list;
|
return args_spec_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<Signature>& signatures() const { return signatures_; }
|
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||||
void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; }
|
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
|
||||||
// Generate a Graph for the given abstract arguments.
|
// Generate a Graph for the given abstract arguments.
|
||||||
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) {
|
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
|
||||||
TypePtrList types;
|
TypePtrList types;
|
||||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
||||||
[](const AbstractBasePtr& arg) -> TypePtr {
|
[](const AbstractBasePtr &arg) -> TypePtr {
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
return arg->BuildType();
|
return arg->BuildType();
|
||||||
});
|
});
|
||||||
|
@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a Graph for this type signature.
|
// Generate a Graph for this type signature.
|
||||||
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) {
|
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
|
||||||
MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
|
MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase {
|
||||||
std::string ToString() const override { return name_; }
|
std::string ToString() const override { return name_; }
|
||||||
std::size_t hash() const override { return tid(); }
|
std::size_t hash() const override { return tid(); }
|
||||||
|
|
||||||
virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; }
|
virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
|
||||||
bool operator==(const Value& other) const override {
|
bool operator==(const Value &other) const override {
|
||||||
if (other.isa<MetaFuncGraph>()) {
|
if (other.isa<MetaFuncGraph>()) {
|
||||||
return &other == this;
|
return &other == this;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace mindspore {
|
||||||
|
|
||||||
namespace tensor {
|
namespace tensor {
|
||||||
|
|
||||||
void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
|
void DataBuf2Contiguous(const py::array &src, py::array *const dest) {
|
||||||
if (dest == nullptr) {
|
if (dest == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
|
MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
|
||||||
}
|
}
|
||||||
|
@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
|
||||||
// MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
|
// MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
|
||||||
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
|
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
|
||||||
|
|
||||||
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {}
|
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
|
||||||
|
|
||||||
MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
|
MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) {
|
||||||
TypeId data_type = TypeId::kTypeUnknown;
|
TypeId data_type = TypeId::kTypeUnknown;
|
||||||
if (type_ptr != nullptr) {
|
if (type_ptr != nullptr) {
|
||||||
data_type = type_ptr->type_id();
|
data_type = type_ptr->type_id();
|
||||||
|
@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MetaTensor::MetaTensor(const MetaTensor& meta_tensor)
|
MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
|
||||||
: Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
|
: Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
|
||||||
|
|
||||||
MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
|
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
|
||||||
if (&meta_tensor == this) {
|
if (&meta_tensor == this) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MetaTensor::operator==(const MetaTensor& meta_tensor) const {
|
bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
|
||||||
return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
|
return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
|
||||||
return type_ptr;
|
return type_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) {
|
void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
|
||||||
DeviceInfo info(format, data_type);
|
DeviceInfo info(format, data_type);
|
||||||
set_device_info(info);
|
set_device_info(info);
|
||||||
}
|
}
|
||||||
|
@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
|
Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) {
|
||||||
TypeId data_type = TypeId::kTypeUnknown;
|
TypeId data_type = TypeId::kTypeUnknown;
|
||||||
if (type_ptr != nullptr) {
|
if (type_ptr != nullptr) {
|
||||||
data_type = type_ptr->type_id();
|
data_type = type_ptr->type_id();
|
||||||
|
@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
|
||||||
init(data_type_, shape_, &data_);
|
init(data_type_, shape_, &data_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); }
|
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); }
|
||||||
|
|
||||||
Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); }
|
Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); }
|
||||||
|
|
||||||
Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); }
|
Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); }
|
||||||
|
|
||||||
Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); }
|
Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); }
|
||||||
|
|
||||||
Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
|
Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
|
||||||
|
|
||||||
Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
|
Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
|
||||||
|
|
||||||
Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type)
|
Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
|
||||||
: MetaTensor(tensor), device_address_(tensor.device_address()) {
|
: MetaTensor(tensor), device_address_(tensor.device_address()) {
|
||||||
init(tensor.data_, data_type);
|
init(tensor.data_, data_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor& Tensor::operator=(const Tensor& tensor) {
|
Tensor &Tensor::operator=(const Tensor &tensor) {
|
||||||
if (this != &tensor) {
|
if (this != &tensor) {
|
||||||
MetaTensor::operator=(tensor);
|
MetaTensor::operator=(tensor);
|
||||||
dirty_ = tensor.is_dirty();
|
dirty_ = tensor.is_dirty();
|
||||||
|
@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tensor::operator==(const Tensor& tensor) const {
|
bool Tensor::operator==(const Tensor &tensor) const {
|
||||||
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
|
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tensor::ValueEqualPy(const py::object& other) const {
|
bool Tensor::ValueEqualPy(const py::object &other) const {
|
||||||
if (!py::isinstance<Tensor>(other)) {
|
if (!py::isinstance<Tensor>(other)) {
|
||||||
MS_LOG(WARNING) << "compare other not a tensor";
|
MS_LOG(WARNING) << "compare other not a tensor";
|
||||||
return false;
|
return false;
|
||||||
|
@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const {
|
||||||
return ValueEqual(py::cast<Tensor>(other));
|
return ValueEqual(py::cast<Tensor>(other));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tensor::ValueEqual(const Tensor& other) const {
|
bool Tensor::ValueEqual(const Tensor &other) const {
|
||||||
auto equal = [&other, this]() -> bool {
|
auto equal = [&other, this]() -> bool {
|
||||||
auto np = py::module::import("numpy");
|
auto np = py::module::import("numpy");
|
||||||
auto equal = np.attr("equal")(data_, other.data_);
|
auto equal = np.attr("equal")(data_, other.data_);
|
||||||
|
@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
|
||||||
|
|
||||||
std::vector<int> Tensor::shape_c(void) const { return shape(); }
|
std::vector<int> Tensor::shape_c(void) const { return shape(); }
|
||||||
|
|
||||||
void* Tensor::data_c(bool writable) {
|
void *Tensor::data_c(bool writable) {
|
||||||
// operand of bit operation should be unsigned int.
|
// operand of bit operation should be unsigned int.
|
||||||
unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
|
unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
|
||||||
bool is_c_contiguous = (flags != 0) ? true : false;
|
bool is_c_contiguous = (flags != 0) ? true : false;
|
||||||
|
@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) {
|
||||||
return data_.request(writable).ptr;
|
return data_.request(writable).ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
|
TypeId Tensor::GetDataType(const py::buffer_info &buf) const {
|
||||||
TypeId data_type = TypeId::kTypeUnknown;
|
TypeId data_type = TypeId::kTypeUnknown;
|
||||||
if (buf.format.compare("e") == 0) {
|
if (buf.format.compare("e") == 0) {
|
||||||
data_type = TypeId::kNumberTypeFloat16;
|
data_type = TypeId::kNumberTypeFloat16;
|
||||||
|
@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
|
||||||
return data_type;
|
return data_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
|
void Tensor::init(const py::array &input, const TypePtr &type_ptr) {
|
||||||
TypeId data_type = TypeId::kTypeUnknown;
|
TypeId data_type = TypeId::kTypeUnknown;
|
||||||
if (type_ptr != nullptr) {
|
if (type_ptr != nullptr) {
|
||||||
data_type = type_ptr->type_id();
|
data_type = type_ptr->type_id();
|
||||||
|
@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
|
||||||
init(input, data_type);
|
init(input, data_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Tensor::init(const py::array& input, const TypeId& data_type) {
|
void Tensor::init(const py::array &input, const TypeId &data_type) {
|
||||||
py::buffer_info buf = input.request();
|
py::buffer_info buf = input.request();
|
||||||
|
|
||||||
data_type_ = GetDataType(buf);
|
data_type_ = GetDataType(buf);
|
||||||
|
@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) {
|
void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
|
||||||
data_type_ = data_type;
|
data_type_ = data_type;
|
||||||
shape_ = shape;
|
shape_ = shape;
|
||||||
switch (data_type) {
|
switch (data_type) {
|
||||||
|
@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
|
||||||
return data_type_;
|
return data_type_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out,
|
bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
|
||||||
const TypeId out_data_type) {
|
const TypeId out_data_type) {
|
||||||
if (out == nullptr) {
|
if (out == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -458,7 +458,7 @@ py::array Tensor::data_sync() {
|
||||||
return data_;
|
return data_;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
||||||
// dtype should define before Tensor, because Tensor init depend dtype
|
// dtype should define before Tensor, because Tensor init depend dtype
|
||||||
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
|
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
|
||||||
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
|
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
|
||||||
|
@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
|
||||||
.def("__repr__", &Tensor::ToStringRepr)
|
.def("__repr__", &Tensor::ToStringRepr)
|
||||||
.def("__eq__", &Tensor::ValueEqualPy)
|
.def("__eq__", &Tensor::ValueEqualPy)
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const Tensor& t) { // __getstate__
|
[](const Tensor &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
return py::make_tuple(t.data());
|
return py::make_tuple(t.data());
|
||||||
},
|
},
|
||||||
[](const py::tuple& t) { // __setstate__
|
[](const py::tuple &t) { // __setstate__
|
||||||
if (t.size() != 1) {
|
if (t.size() != 1) {
|
||||||
throw std::runtime_error("Invalid state!");
|
throw std::runtime_error("Invalid state!");
|
||||||
}
|
}
|
||||||
|
|
|
@ -131,16 +131,16 @@ class MetaTensor : public Value {
|
||||||
// information of a Tensor. The following codes will create a 2x3 float
|
// information of a Tensor. The following codes will create a 2x3 float
|
||||||
// param data_type The data type of the tensor.
|
// param data_type The data type of the tensor.
|
||||||
// param shape The shape of the tensor.
|
// param shape The shape of the tensor.
|
||||||
MetaTensor(const TypeId data_type, const std::vector<int>& shape);
|
MetaTensor(const TypeId data_type, const std::vector<int> &shape);
|
||||||
|
|
||||||
MetaTensor(const TypePtr& type_ptr, const py::tuple& shape);
|
MetaTensor(const TypePtr &type_ptr, const py::tuple &shape);
|
||||||
// brief Constructs a MetaTensor object from an existing MetaTensor instance.
|
// brief Constructs a MetaTensor object from an existing MetaTensor instance.
|
||||||
//
|
//
|
||||||
// The constructed MetaTensor object will have the same data type and shape as the
|
// The constructed MetaTensor object will have the same data type and shape as the
|
||||||
// meta_tensor.
|
// meta_tensor.
|
||||||
//
|
//
|
||||||
// param meta_tensor An existing MetaTensor object.
|
// param meta_tensor An existing MetaTensor object.
|
||||||
MetaTensor(const MetaTensor& meta_tensor);
|
MetaTensor(const MetaTensor &meta_tensor);
|
||||||
~MetaTensor() override = default;
|
~MetaTensor() override = default;
|
||||||
MS_DECLARE_PARENT(MetaTensor, Value)
|
MS_DECLARE_PARENT(MetaTensor, Value)
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ class MetaTensor : public Value {
|
||||||
// The constructed MetaTensor object has the same type and shape with meta_tensor.
|
// The constructed MetaTensor object has the same type and shape with meta_tensor.
|
||||||
//
|
//
|
||||||
// param meta_tensor An existing MetaTensor object.
|
// param meta_tensor An existing MetaTensor object.
|
||||||
virtual MetaTensor& operator=(const MetaTensor& meta_tensor);
|
virtual MetaTensor &operator=(const MetaTensor &meta_tensor);
|
||||||
|
|
||||||
// brief Compares two MetaTensor objects.
|
// brief Compares two MetaTensor objects.
|
||||||
//
|
//
|
||||||
|
@ -157,7 +157,7 @@ class MetaTensor : public Value {
|
||||||
//
|
//
|
||||||
// param meta_tensor The MetaTensor object to be compared.
|
// param meta_tensor The MetaTensor object to be compared.
|
||||||
// return true: If having same type and shape, return true, or return false.
|
// return true: If having same type and shape, return true, or return false.
|
||||||
virtual bool operator==(const MetaTensor& meta_tensor) const;
|
virtual bool operator==(const MetaTensor &meta_tensor) const;
|
||||||
|
|
||||||
// brief Returns the data type of the tensor in its MetaTensor.
|
// brief Returns the data type of the tensor in its MetaTensor.
|
||||||
//
|
//
|
||||||
|
@ -193,7 +193,7 @@ class MetaTensor : public Value {
|
||||||
//
|
//
|
||||||
// param shape The shape of the tensor.
|
// param shape The shape of the tensor.
|
||||||
// return The shape's size.
|
// return The shape's size.
|
||||||
size_t set_shape(const std::vector<int>& shape) {
|
size_t set_shape(const std::vector<int> &shape) {
|
||||||
this->shape_ = shape;
|
this->shape_ = shape;
|
||||||
return shape_.size();
|
return shape_.size();
|
||||||
}
|
}
|
||||||
|
@ -202,9 +202,9 @@ class MetaTensor : public Value {
|
||||||
DeviceInfo device_info() const { return device_info_; }
|
DeviceInfo device_info() const { return device_info_; }
|
||||||
|
|
||||||
// Set tensor's device info.
|
// Set tensor's device info.
|
||||||
void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; }
|
void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
|
||||||
|
|
||||||
void SetDeviceInfo(const std::string& format, const TypePtr& data_type);
|
void SetDeviceInfo(const std::string &format, const TypePtr &data_type);
|
||||||
|
|
||||||
// Get the size of a given dimension by its index number.
|
// Get the size of a given dimension by its index number.
|
||||||
int DimensionSize(size_t index) const;
|
int DimensionSize(size_t index) const;
|
||||||
|
@ -222,9 +222,9 @@ class MetaTensor : public Value {
|
||||||
}
|
}
|
||||||
return hash_value;
|
return hash_value;
|
||||||
}
|
}
|
||||||
bool operator==(const Value& other) const override {
|
bool operator==(const Value &other) const override {
|
||||||
if (other.isa<MetaTensor>()) {
|
if (other.isa<MetaTensor>()) {
|
||||||
auto other_ = static_cast<const MetaTensor&>(other);
|
auto other_ = static_cast<const MetaTensor &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
@ -262,49 +262,49 @@ class Tensor : public MetaTensor {
|
||||||
//
|
//
|
||||||
// param type_ptr [TypePty] Data type of the tensor.
|
// param type_ptr [TypePty] Data type of the tensor.
|
||||||
// param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
|
// param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
|
||||||
Tensor(const TypePtr& type_ptr, const py::tuple& shape);
|
Tensor(const TypePtr &type_ptr, const py::tuple &shape);
|
||||||
|
|
||||||
// brief Constructor for C++.
|
// brief Constructor for C++.
|
||||||
//
|
//
|
||||||
// param data_type [TypeId] Data type of the tensor.
|
// param data_type [TypeId] Data type of the tensor.
|
||||||
// param shape The shape represented by std::vector<int> of the tensor.
|
// param shape The shape represented by std::vector<int> of the tensor.
|
||||||
Tensor(TypeId data_type, const std::vector<int>& shape);
|
Tensor(TypeId data_type, const std::vector<int> &shape);
|
||||||
|
|
||||||
// brief Constructor for Python.
|
// brief Constructor for Python.
|
||||||
//
|
//
|
||||||
// param input [py::array] Data value of the tensor.
|
// param input [py::array] Data value of the tensor.
|
||||||
// param data_type [TypeId] Data type of the tensor.
|
// param data_type [TypeId] Data type of the tensor.
|
||||||
explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr);
|
explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr);
|
||||||
|
|
||||||
// brief Constructor
|
// brief Constructor
|
||||||
//
|
//
|
||||||
// param input [py::list] the data for tensor
|
// param input [py::list] the data for tensor
|
||||||
// param data_type [TypeId] data type
|
// param data_type [TypeId] data type
|
||||||
explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr);
|
explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr);
|
||||||
|
|
||||||
// brief Constructor
|
// brief Constructor
|
||||||
//
|
//
|
||||||
// param input [py::tuple] the data for tensor
|
// param input [py::tuple] the data for tensor
|
||||||
// param data_type [TypeId] data type
|
// param data_type [TypeId] data type
|
||||||
explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr);
|
explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr);
|
||||||
|
|
||||||
// brief Constructor
|
// brief Constructor
|
||||||
//
|
//
|
||||||
// param input [py::float_] the data for tensor
|
// param input [py::float_] the data for tensor
|
||||||
// param data_type [TypeId] data type
|
// param data_type [TypeId] data type
|
||||||
explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr);
|
explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr);
|
||||||
|
|
||||||
// brief Constructor
|
// brief Constructor
|
||||||
//
|
//
|
||||||
// param input [py::int_] the data for tensor
|
// param input [py::int_] the data for tensor
|
||||||
// param data_type [TypeId] data type
|
// param data_type [TypeId] data type
|
||||||
explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr);
|
explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr);
|
||||||
|
|
||||||
// brief Constructor
|
// brief Constructor
|
||||||
//
|
//
|
||||||
// param input [Tensor] the data for tensor
|
// param input [Tensor] the data for tensor
|
||||||
// param data_type [TypeId] data type
|
// param data_type [TypeId] data type
|
||||||
Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr);
|
Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr);
|
||||||
|
|
||||||
~Tensor() override = default;
|
~Tensor() override = default;
|
||||||
|
|
||||||
|
@ -315,7 +315,7 @@ class Tensor : public MetaTensor {
|
||||||
// The constructed Tensor object has the same type and shape with tensor.
|
// The constructed Tensor object has the same type and shape with tensor.
|
||||||
//
|
//
|
||||||
// param tensor An existing Tensor object.
|
// param tensor An existing Tensor object.
|
||||||
Tensor& operator=(const Tensor& tensor);
|
Tensor &operator=(const Tensor &tensor);
|
||||||
|
|
||||||
// brief Compares two Tensor objects.
|
// brief Compares two Tensor objects.
|
||||||
//
|
//
|
||||||
|
@ -324,17 +324,17 @@ class Tensor : public MetaTensor {
|
||||||
//
|
//
|
||||||
// param tensor The Tensor object to be compared.
|
// param tensor The Tensor object to be compared.
|
||||||
// return true: If having same type, shape and data, return true, or return false.
|
// return true: If having same type, shape and data, return true, or return false.
|
||||||
bool operator==(const Tensor& tensor) const;
|
bool operator==(const Tensor &tensor) const;
|
||||||
|
|
||||||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
||||||
bool ValueEqual(const Tensor& other) const;
|
bool ValueEqual(const Tensor &other) const;
|
||||||
|
|
||||||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
||||||
bool ValueEqualPy(const py::object& other) const;
|
bool ValueEqualPy(const py::object &other) const;
|
||||||
|
|
||||||
bool operator==(const Value& other) const override {
|
bool operator==(const Value &other) const override {
|
||||||
if (other.isa<Tensor>()) {
|
if (other.isa<Tensor>()) {
|
||||||
auto other_ = static_cast<const Tensor&>(other);
|
auto other_ = static_cast<const Tensor &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
@ -375,13 +375,13 @@ class Tensor : public MetaTensor {
|
||||||
//
|
//
|
||||||
// param writable true if writable, false if read only
|
// param writable true if writable, false if read only
|
||||||
// return The pointer to the object
|
// return The pointer to the object
|
||||||
void* data_c(bool writable = false);
|
void *data_c(bool writable = false);
|
||||||
|
|
||||||
// brief Get data type from tensor data.
|
// brief Get data type from tensor data.
|
||||||
//
|
//
|
||||||
// param buf The buffer info of the py::array data.
|
// param buf The buffer info of the py::array data.
|
||||||
// return The [TypeId] of the tensor data.
|
// return The [TypeId] of the tensor data.
|
||||||
TypeId GetDataType(const py::buffer_info& buf) const;
|
TypeId GetDataType(const py::buffer_info &buf) const;
|
||||||
|
|
||||||
// brief Sets the data type of a tensor.
|
// brief Sets the data type of a tensor.
|
||||||
//
|
//
|
||||||
|
@ -401,23 +401,23 @@ class Tensor : public MetaTensor {
|
||||||
// param input [py::array] the data for tensor
|
// param input [py::array] the data for tensor
|
||||||
// param data_type [TypeId] data type
|
// param data_type [TypeId] data type
|
||||||
// return true if succeed, false if failed.
|
// return true if succeed, false if failed.
|
||||||
void init(const py::array& input, const TypeId& data_type);
|
void init(const py::array &input, const TypeId &data_type);
|
||||||
void init(const py::array& input, const TypePtr& type_ptr);
|
void init(const py::array &input, const TypePtr &type_ptr);
|
||||||
|
|
||||||
// brief init tensor attribute
|
// brief init tensor attribute
|
||||||
//
|
//
|
||||||
// param data_type [TypeId] Data type of the tensor.
|
// param data_type [TypeId] Data type of the tensor.
|
||||||
// param shape [py::array] The shape of the tensor.
|
// param shape [py::array] The shape of the tensor.
|
||||||
// return true if succeed, false if failed.
|
// return true if succeed, false if failed.
|
||||||
void init(TypeId data_type, const std::vector<int>& shape, py::array* data);
|
void init(TypeId data_type, const std::vector<int> &shape, py::array *data);
|
||||||
|
|
||||||
bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type);
|
bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
bool is_dirty() const { return dirty_; }
|
bool is_dirty() const { return dirty_; }
|
||||||
void set_dirty(const bool dirty) { dirty_ = dirty; }
|
void set_dirty(const bool dirty) { dirty_ = dirty; }
|
||||||
DeviceAddressPtr device_address() const { return device_address_; }
|
DeviceAddressPtr device_address() const { return device_address_; }
|
||||||
void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; }
|
void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
|
||||||
py::array data_sync();
|
py::array data_sync();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,9 +18,9 @@
|
||||||
#include "pipeline/static_analysis/abstract_value.h"
|
#include "pipeline/static_analysis/abstract_value.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
bool Named::operator==(const Value& other) const {
|
bool Named::operator==(const Value &other) const {
|
||||||
if (other.isa<Named>()) {
|
if (other.isa<Named>()) {
|
||||||
auto other_named = static_cast<const Named&>(other);
|
auto other_named = static_cast<const Named &>(other);
|
||||||
return *this == other_named;
|
return *this == other_named;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -27,18 +27,18 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class Named : public Value {
|
class Named : public Value {
|
||||||
public:
|
public:
|
||||||
explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
|
explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
|
||||||
Named(const Named& other) : Value(other) {
|
Named(const Named &other) : Value(other) {
|
||||||
this->name_ = other.name_;
|
this->name_ = other.name_;
|
||||||
hash_id_ = std::hash<std::string>{}(other.name_);
|
hash_id_ = std::hash<std::string>{}(other.name_);
|
||||||
}
|
}
|
||||||
~Named() override = default;
|
~Named() override = default;
|
||||||
MS_DECLARE_PARENT(Named, Value);
|
MS_DECLARE_PARENT(Named, Value);
|
||||||
|
|
||||||
const std::string& name() const { return name_; }
|
const std::string &name() const { return name_; }
|
||||||
virtual bool operator==(const Named& other) const { return name_ == other.name(); }
|
virtual bool operator==(const Named &other) const { return name_ == other.name(); }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
Named& operator=(const Named& other) {
|
Named &operator=(const Named &other) {
|
||||||
if (&other != this) {
|
if (&other != this) {
|
||||||
this->type_ = other.type_;
|
this->type_ = other.type_;
|
||||||
this->name_ = other.name_;
|
this->name_ = other.name_;
|
||||||
|
@ -50,7 +50,7 @@ class Named : public Value {
|
||||||
std::size_t Hash() const { return hash_id_; }
|
std::size_t Hash() const { return hash_id_; }
|
||||||
std::size_t hash() const override { return hash_id_; }
|
std::size_t hash() const override { return hash_id_; }
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const Named& nmd) {
|
friend std::ostream &operator<<(std::ostream &os, const Named &nmd) {
|
||||||
os << nmd.name();
|
os << nmd.name();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
using mindspore::abstract::AbstractFunction;
|
using mindspore::abstract::AbstractFunction;
|
||||||
|
|
||||||
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) {
|
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
|
||||||
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
|
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
|
||||||
return prim_func;
|
return prim_func;
|
||||||
}
|
}
|
||||||
|
@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() {
|
||||||
return fn;
|
return fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Primitive::operator==(const Value& other) const {
|
bool Primitive::operator==(const Value &other) const {
|
||||||
if (other.isa<Primitive>()) {
|
if (other.isa<Primitive>()) {
|
||||||
auto other_prim = static_cast<const Primitive&>(other);
|
auto other_prim = static_cast<const Primitive &>(other);
|
||||||
return *this == other_prim;
|
return *this == other_prim;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Primitive::operator==(const Primitive& other) const {
|
bool Primitive::operator==(const Primitive &other) const {
|
||||||
if (name() != other.name()) {
|
if (name() != other.name()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (attrs_.size() != other.attrs_.size()) {
|
if (attrs_.size() != other.attrs_.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr>& item) -> bool {
|
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
|
||||||
if (item.second == nullptr) {
|
if (item.second == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const {
|
||||||
void Primitive::set_signatures(
|
void Primitive::set_signatures(
|
||||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
|
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
|
||||||
signatures_.clear();
|
signatures_.clear();
|
||||||
for (auto& signature : signatures) {
|
for (auto &signature : signatures) {
|
||||||
std::string name;
|
std::string name;
|
||||||
SignatureEnumRW rw;
|
SignatureEnumRW rw;
|
||||||
SignatureEnumKind kind;
|
SignatureEnumKind kind;
|
||||||
|
@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "[";
|
oss << "[";
|
||||||
bool is_first = true;
|
bool is_first = true;
|
||||||
for (auto& attr : attrs_) {
|
for (auto &attr : attrs_) {
|
||||||
if (is_first) {
|
if (is_first) {
|
||||||
is_first = false;
|
is_first = false;
|
||||||
} else {
|
} else {
|
||||||
|
@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
py::function PrimitivePy::GetBpropFunction() {
|
py::function PrimitivePy::GetBpropFunction() {
|
||||||
static const char* const get_bprop_func_name = "get_bprop";
|
static const char *const get_bprop_func_name = "get_bprop";
|
||||||
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
||||||
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
|
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
|
||||||
return fn;
|
return fn;
|
||||||
|
@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() {
|
||||||
}
|
}
|
||||||
|
|
||||||
py::function PrimitivePy::GetComputeFunction() {
|
py::function PrimitivePy::GetComputeFunction() {
|
||||||
static const char* const compute_func_name = "vm_impl";
|
static const char *const compute_func_name = "vm_impl";
|
||||||
|
|
||||||
if (py::hasattr(python_obj_, compute_func_name)) {
|
if (py::hasattr(python_obj_, compute_func_name)) {
|
||||||
MS_LOG(INFO) << "" << name() << " compute_func_name";
|
MS_LOG(INFO) << "" << name() << " compute_func_name";
|
||||||
|
@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() {
|
||||||
return vm_fn;
|
return vm_fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) {
|
void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
|
||||||
std::string attr_name = name;
|
std::string attr_name = name;
|
||||||
ValuePtr converted_ret = nullptr;
|
ValuePtr converted_ret = nullptr;
|
||||||
if (py::isinstance<py::module>(obj)) {
|
if (py::isinstance<py::module>(obj)) {
|
||||||
|
@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) {
|
||||||
|
|
||||||
py::dict PrimitivePy::GetAttrDict() {
|
py::dict PrimitivePy::GetAttrDict() {
|
||||||
py::dict attr_dict;
|
py::dict attr_dict;
|
||||||
for (auto& attr : attrs_) {
|
for (auto &attr : attrs_) {
|
||||||
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
||||||
}
|
}
|
||||||
return attr_dict;
|
return attr_dict;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||||
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
||||||
|
@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) {
|
||||||
.value("user_custom", PrimType::kPrimTypeUserCustom);
|
.value("user_custom", PrimType::kPrimTypeUserCustom);
|
||||||
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
||||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
||||||
.def(py::init<py::str&, py::object>())
|
.def(py::init<py::str &, py::object>())
|
||||||
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
|
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
|
||||||
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
|
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
|
||||||
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
|
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
|
||||||
|
|
|
@ -48,25 +48,25 @@ enum PrimType {
|
||||||
|
|
||||||
class Primitive : public Named {
|
class Primitive : public Named {
|
||||||
public:
|
public:
|
||||||
explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn)
|
explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn)
|
||||||
: Named(name), signatures_(), prim_type_(prim_type) {}
|
: Named(name), signatures_(), prim_type_(prim_type) {}
|
||||||
|
|
||||||
Primitive(const Primitive& prim)
|
Primitive(const Primitive &prim)
|
||||||
: Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
|
: Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Primitive, Named);
|
MS_DECLARE_PARENT(Primitive, Named);
|
||||||
|
|
||||||
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node);
|
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
||||||
std::string ToString() const override { return name(); }
|
std::string ToString() const override { return name(); }
|
||||||
virtual py::function GetBpropFunction();
|
virtual py::function GetBpropFunction();
|
||||||
virtual py::function GetComputeFunction();
|
virtual py::function GetComputeFunction();
|
||||||
Primitive& AddAttr(const std::string& name, const ValuePtr& attr) {
|
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||||
attrs_[name] = attr;
|
attrs_[name] = attr;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
Primitive& SetAttrs(const std::unordered_map<std::string, ValuePtr>& attrs) {
|
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||||
for (auto& attr : attrs) {
|
for (auto &attr : attrs) {
|
||||||
attrs_[attr.first] = attr.second;
|
attrs_[attr.first] = attr.second;
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
|
@ -76,21 +76,21 @@ class Primitive : public Named {
|
||||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
|
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
|
||||||
signatures);
|
signatures);
|
||||||
|
|
||||||
const std::vector<Signature>& signatures() const { return signatures_; }
|
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||||
|
|
||||||
void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; }
|
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
|
||||||
void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); }
|
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
|
||||||
|
|
||||||
ValuePtr GetAttr(const std::string& attrName) const {
|
ValuePtr GetAttr(const std::string &attrName) const {
|
||||||
auto iter = attrs_.find(attrName);
|
auto iter = attrs_.find(attrName);
|
||||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<std::string, ValuePtr>& attrs() const { return attrs_; }
|
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||||
|
|
||||||
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
||||||
bool HasAttr() const { return !attrs_.empty(); }
|
bool HasAttr() const { return !attrs_.empty(); }
|
||||||
bool HasAttr(const std::string& attrName) const {
|
bool HasAttr(const std::string &attrName) const {
|
||||||
auto iter = attrs_.find(attrName);
|
auto iter = attrs_.find(attrName);
|
||||||
return !(iter == attrs_.cend());
|
return !(iter == attrs_.cend());
|
||||||
}
|
}
|
||||||
|
@ -103,8 +103,8 @@ class Primitive : public Named {
|
||||||
PrimType prim_type() const { return prim_type_; }
|
PrimType prim_type() const { return prim_type_; }
|
||||||
std::string instance_name() const { return instance_name_; }
|
std::string instance_name() const { return instance_name_; }
|
||||||
std::string GetAttrsText() const;
|
std::string GetAttrsText() const;
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const Primitive& other) const;
|
bool operator==(const Primitive &other) const;
|
||||||
~Primitive() override = default;
|
~Primitive() override = default;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -118,18 +118,18 @@ class Primitive : public Named {
|
||||||
|
|
||||||
class PrimitivePy : public Primitive {
|
class PrimitivePy : public Primitive {
|
||||||
public:
|
public:
|
||||||
PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {}
|
PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {}
|
||||||
~PrimitivePy() override = default;
|
~PrimitivePy() override = default;
|
||||||
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
||||||
py::function GetBpropFunction() override;
|
py::function GetBpropFunction() override;
|
||||||
py::function GetComputeFunction() override;
|
py::function GetComputeFunction() override;
|
||||||
|
|
||||||
void AddPyAttr(const py::str& name, const py::object& obj);
|
void AddPyAttr(const py::str &name, const py::object &obj);
|
||||||
|
|
||||||
py::dict GetAttrDict();
|
py::dict GetAttrDict();
|
||||||
|
|
||||||
const bool parse_info_ = true;
|
const bool parse_info_ = true;
|
||||||
const py::object& GetPyObj() const { return python_obj_; }
|
const py::object &GetPyObj() const { return python_obj_; }
|
||||||
bool is_tuple_input_ = false;
|
bool is_tuple_input_ = false;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -138,13 +138,13 @@ class PrimitivePy : public Primitive {
|
||||||
|
|
||||||
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
|
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) {
|
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
||||||
os << *p;
|
os << *p;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PrimitiveEqual {
|
struct PrimitiveEqual {
|
||||||
bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const {
|
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||||
MS_EXCEPTION_IF_NULL(t1);
|
MS_EXCEPTION_IF_NULL(t1);
|
||||||
MS_EXCEPTION_IF_NULL(t2);
|
MS_EXCEPTION_IF_NULL(t2);
|
||||||
return t1->name() == t2->name();
|
return t1->name() == t2->name();
|
||||||
|
@ -152,7 +152,7 @@ struct PrimitiveEqual {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PrimitiveHasher {
|
struct PrimitiveHasher {
|
||||||
std::size_t operator()(PrimitivePtr const& prim) const {
|
std::size_t operator()(PrimitivePtr const &prim) const {
|
||||||
std::size_t hash = std::hash<std::string>()(prim->name());
|
std::size_t hash = std::hash<std::string>()(prim->name());
|
||||||
return hash;
|
return hash;
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,8 +55,8 @@ class BoolImm : public Scalar {
|
||||||
bool value() const { return v_; }
|
bool value() const { return v_; }
|
||||||
bool IsZero() override { return v_ == false; }
|
bool IsZero() override { return v_ == false; }
|
||||||
bool IsOne() override { return v_ == true; }
|
bool IsOne() override { return v_ == true; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const BoolImm& other) const;
|
bool operator==(const BoolImm &other) const;
|
||||||
std::string ToString() const override {
|
std::string ToString() const override {
|
||||||
if (v_) {
|
if (v_) {
|
||||||
return "true";
|
return "true";
|
||||||
|
@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool)
|
||||||
class IntergerImm : public Scalar {
|
class IntergerImm : public Scalar {
|
||||||
public:
|
public:
|
||||||
IntergerImm() = default;
|
IntergerImm() = default;
|
||||||
explicit IntergerImm(const TypePtr& t) : Scalar(t) {}
|
explicit IntergerImm(const TypePtr &t) : Scalar(t) {}
|
||||||
~IntergerImm() override = default;
|
~IntergerImm() override = default;
|
||||||
MS_DECLARE_PARENT(IntergerImm, Scalar)
|
MS_DECLARE_PARENT(IntergerImm, Scalar)
|
||||||
};
|
};
|
||||||
|
@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
int8_t value() const { return v_; }
|
int8_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const Int8Imm& other) const;
|
bool operator==(const Int8Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
int16_t value() const { return v_; }
|
int16_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const Int16Imm& other) const;
|
bool operator==(const Int16Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
int32_t value() const { return v_; }
|
int32_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const Int32Imm& other) const;
|
bool operator==(const Int32Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
int64_t value() const { return v_; }
|
int64_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const Int64Imm& other) const;
|
bool operator==(const Int64Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
uint8_t value() const { return v_; }
|
uint8_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const UInt8Imm& other) const;
|
bool operator==(const UInt8Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
uint16_t value() const { return v_; }
|
uint16_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const UInt16Imm& other) const;
|
bool operator==(const UInt16Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
uint32_t value() const { return v_; }
|
uint32_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const UInt32Imm& other) const;
|
bool operator==(const UInt32Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm {
|
||||||
bool IsZero() override { return v_ == 0; }
|
bool IsZero() override { return v_ == 0; }
|
||||||
bool IsOne() override { return v_ == 1; }
|
bool IsOne() override { return v_ == 1; }
|
||||||
uint64_t value() const { return v_; }
|
uint64_t value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const UInt64Imm& other) const;
|
bool operator==(const UInt64Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t);
|
||||||
class FloatImm : public Scalar {
|
class FloatImm : public Scalar {
|
||||||
public:
|
public:
|
||||||
FloatImm() = default;
|
FloatImm() = default;
|
||||||
explicit FloatImm(const TypePtr& t) : Scalar(t) {}
|
explicit FloatImm(const TypePtr &t) : Scalar(t) {}
|
||||||
~FloatImm() override = default;
|
~FloatImm() override = default;
|
||||||
MS_DECLARE_PARENT(FloatImm, Scalar)
|
MS_DECLARE_PARENT(FloatImm, Scalar)
|
||||||
};
|
};
|
||||||
|
@ -312,8 +312,8 @@ class FP32Imm : public FloatImm {
|
||||||
bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
|
bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
|
||||||
bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
|
bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
|
||||||
float value() const { return v_; }
|
float value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const FP32Imm& other) const;
|
bool operator==(const FP32Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
@ -338,8 +338,8 @@ class FP64Imm : public FloatImm {
|
||||||
bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
|
bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
|
||||||
bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
|
bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
|
||||||
double value() const { return v_; }
|
double value() const { return v_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const FP64Imm& other) const;
|
bool operator==(const FP64Imm &other) const;
|
||||||
std::string ToString() const override { return std::to_string(v_); }
|
std::string ToString() const override { return std::to_string(v_); }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
#include "pipeline/parse/data_converter.h"
|
#include "pipeline/parse/data_converter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind,
|
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
|
||||||
const py::object& arg_default, const SignatureEnumDType& arg_dtype)
|
const py::object &arg_default, const SignatureEnumDType &arg_dtype)
|
||||||
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
|
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
|
||||||
if (py::isinstance<SignatureEnumKind>(arg_default) &&
|
if (py::isinstance<SignatureEnumKind>(arg_default) &&
|
||||||
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
|
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
|
||||||
|
@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind)
|
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
|
||||||
: name(arg_name),
|
: name(arg_name),
|
||||||
rw(rw_tag),
|
rw(rw_tag),
|
||||||
kind(arg_kind),
|
kind(arg_kind),
|
||||||
default_value(nullptr),
|
default_value(nullptr),
|
||||||
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
|
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
|
||||||
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
|
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
|
||||||
.value("RW_READ", SignatureEnumRW::kRWRead)
|
.value("RW_READ", SignatureEnumRW::kRWRead)
|
||||||
.value("RW_WRITE", SignatureEnumRW::kRWWrite)
|
.value("RW_WRITE", SignatureEnumRW::kRWWrite)
|
||||||
|
|
|
@ -61,9 +61,9 @@ struct Signature {
|
||||||
SignatureEnumKind kind;
|
SignatureEnumKind kind;
|
||||||
ValuePtr default_value; // nullptr for no default value
|
ValuePtr default_value; // nullptr for no default value
|
||||||
SignatureEnumDType dtype;
|
SignatureEnumDType dtype;
|
||||||
Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind,
|
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
|
||||||
const py::object& arg_default, const SignatureEnumDType& arg_dtype);
|
const py::object &arg_default, const SignatureEnumDType &arg_dtype);
|
||||||
Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind);
|
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind);
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include "pipeline/static_analysis/abstract_value.h"
|
#include "pipeline/static_analysis/abstract_value.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const {
|
const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
|
||||||
if (dim >= size()) {
|
if (dim >= size()) {
|
||||||
MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "].";
|
MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "].";
|
||||||
}
|
}
|
||||||
|
@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BoolImm::operator==(const Value& other) const {
|
bool BoolImm::operator==(const Value &other) const {
|
||||||
if (other.isa<BoolImm>()) {
|
if (other.isa<BoolImm>()) {
|
||||||
auto other_ = static_cast<const BoolImm&>(other);
|
auto other_ = static_cast<const BoolImm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; }
|
bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; }
|
||||||
|
|
||||||
bool Int8Imm::operator==(const Value& other) const {
|
bool Int8Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<Int8Imm>()) {
|
if (other.isa<Int8Imm>()) {
|
||||||
auto other_ = static_cast<const Int8Imm&>(other);
|
auto other_ = static_cast<const Int8Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; }
|
bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; }
|
||||||
bool Int16Imm::operator==(const Value& other) const {
|
bool Int16Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<Int16Imm>()) {
|
if (other.isa<Int16Imm>()) {
|
||||||
auto other_ = static_cast<const Int16Imm&>(other);
|
auto other_ = static_cast<const Int16Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; }
|
bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; }
|
||||||
bool Int32Imm::operator==(const Value& other) const {
|
bool Int32Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<Int32Imm>()) {
|
if (other.isa<Int32Imm>()) {
|
||||||
auto other_ = static_cast<const Int32Imm&>(other);
|
auto other_ = static_cast<const Int32Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; }
|
bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; }
|
||||||
bool Int64Imm::operator==(const Value& other) const {
|
bool Int64Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<Int64Imm>()) {
|
if (other.isa<Int64Imm>()) {
|
||||||
auto other_ = static_cast<const Int64Imm&>(other);
|
auto other_ = static_cast<const Int64Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; }
|
bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; }
|
||||||
bool UInt8Imm::operator==(const Value& other) const {
|
bool UInt8Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<UInt8Imm>()) {
|
if (other.isa<UInt8Imm>()) {
|
||||||
auto other_ = static_cast<const UInt8Imm&>(other);
|
auto other_ = static_cast<const UInt8Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; }
|
bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; }
|
||||||
bool UInt16Imm::operator==(const Value& other) const {
|
bool UInt16Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<UInt16Imm>()) {
|
if (other.isa<UInt16Imm>()) {
|
||||||
auto other_ = static_cast<const UInt16Imm&>(other);
|
auto other_ = static_cast<const UInt16Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; }
|
bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; }
|
||||||
bool UInt32Imm::operator==(const Value& other) const {
|
bool UInt32Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<UInt32Imm>()) {
|
if (other.isa<UInt32Imm>()) {
|
||||||
auto other_ = static_cast<const UInt32Imm&>(other);
|
auto other_ = static_cast<const UInt32Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; }
|
bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; }
|
||||||
bool UInt64Imm::operator==(const Value& other) const {
|
bool UInt64Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<UInt64Imm>()) {
|
if (other.isa<UInt64Imm>()) {
|
||||||
auto other_ = static_cast<const UInt64Imm&>(other);
|
auto other_ = static_cast<const UInt64Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; }
|
bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; }
|
||||||
bool FP32Imm::operator==(const Value& other) const {
|
bool FP32Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<FP32Imm>()) {
|
if (other.isa<FP32Imm>()) {
|
||||||
auto other_ = static_cast<const FP32Imm&>(other);
|
auto other_ = static_cast<const FP32Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; }
|
bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; }
|
||||||
bool FP64Imm::operator==(const Value& other) const {
|
bool FP64Imm::operator==(const Value &other) const {
|
||||||
if (other.isa<FP64Imm>()) {
|
if (other.isa<FP64Imm>()) {
|
||||||
auto other_ = static_cast<const FP64Imm&>(other);
|
auto other_ = static_cast<const FP64Imm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool ValueSequeue::operator==(const Value& other) const {
|
bool ValueSequeue::operator==(const Value &other) const {
|
||||||
if (other.isa<ValueSequeue>()) {
|
if (other.isa<ValueSequeue>()) {
|
||||||
auto other_ = static_cast<const ValueSequeue&>(other);
|
auto other_ = static_cast<const ValueSequeue &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool ValueSequeue::operator==(const ValueSequeue& other) const {
|
bool ValueSequeue::operator==(const ValueSequeue &other) const {
|
||||||
if (other.elements_.size() != elements_.size()) {
|
if (other.elements_.size() != elements_.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(),
|
return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(),
|
||||||
[](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; });
|
[](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; });
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ValueSequeue::ToString() const {
|
std::string ValueSequeue::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
bool begin = true;
|
bool begin = true;
|
||||||
for (auto& attr : elements_) {
|
for (auto &attr : elements_) {
|
||||||
if (!begin) {
|
if (!begin) {
|
||||||
buffer << ", ";
|
buffer << ", ";
|
||||||
} else {
|
} else {
|
||||||
|
@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; }
|
bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; }
|
||||||
bool StringImm::operator==(const Value& other) const {
|
bool StringImm::operator==(const Value &other) const {
|
||||||
if (other.isa<StringImm>()) {
|
if (other.isa<StringImm>()) {
|
||||||
auto other_ = static_cast<const StringImm&>(other);
|
auto other_ = static_cast<const StringImm &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; }
|
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }
|
||||||
|
|
||||||
bool RefKey::operator==(const Value& other) const {
|
bool RefKey::operator==(const Value &other) const {
|
||||||
if (other.isa<RefKey>()) {
|
if (other.isa<RefKey>()) {
|
||||||
auto other_ = static_cast<const RefKey&>(other);
|
auto other_ = static_cast<const RefKey &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; }
|
bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
|
||||||
|
|
||||||
bool AnyValue::operator==(const Value& other) const {
|
bool AnyValue::operator==(const Value &other) const {
|
||||||
if (other.isa<AnyValue>()) {
|
if (other.isa<AnyValue>()) {
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
|
@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstr
|
||||||
|
|
||||||
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
|
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
|
||||||
abstract::AbstractBasePtrList a_list;
|
abstract::AbstractBasePtrList a_list;
|
||||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) {
|
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
|
||||||
MS_EXCEPTION_IF_NULL(ele);
|
MS_EXCEPTION_IF_NULL(ele);
|
||||||
return ele->ToAbstract();
|
return ele->ToAbstract();
|
||||||
});
|
});
|
||||||
|
@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() {
|
||||||
|
|
||||||
abstract::AbstractBasePtr ValueList::ToAbstract() {
|
abstract::AbstractBasePtr ValueList::ToAbstract() {
|
||||||
abstract::AbstractBasePtrList a_list;
|
abstract::AbstractBasePtrList a_list;
|
||||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) {
|
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
|
||||||
MS_EXCEPTION_IF_NULL(ele);
|
MS_EXCEPTION_IF_NULL(ele);
|
||||||
return ele->ToAbstract();
|
return ele->ToAbstract();
|
||||||
});
|
});
|
||||||
|
@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const {
|
||||||
return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
|
return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValueSlice::operator==(const Value& other) const {
|
bool ValueSlice::operator==(const Value &other) const {
|
||||||
if (other.isa<ValueSlice>()) {
|
if (other.isa<ValueSlice>()) {
|
||||||
auto other_ = static_cast<const ValueSlice&>(other);
|
auto other_ = static_cast<const ValueSlice &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValueSlice::operator==(const ValueSlice& other) const {
|
bool ValueSlice::operator==(const ValueSlice &other) const {
|
||||||
MS_EXCEPTION_IF_NULL(start_);
|
MS_EXCEPTION_IF_NULL(start_);
|
||||||
MS_EXCEPTION_IF_NULL(stop_);
|
MS_EXCEPTION_IF_NULL(stop_);
|
||||||
MS_EXCEPTION_IF_NULL(step_);
|
MS_EXCEPTION_IF_NULL(step_);
|
||||||
|
@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const {
|
||||||
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
|
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool KeywordArg::operator==(const Value& other) const {
|
bool KeywordArg::operator==(const Value &other) const {
|
||||||
if (other.isa<KeywordArg>()) {
|
if (other.isa<KeywordArg>()) {
|
||||||
auto other_ = static_cast<const KeywordArg&>(other);
|
auto other_ = static_cast<const KeywordArg &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); }
|
bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); }
|
||||||
|
|
||||||
std::string KeywordArg::ToString() const {
|
std::string KeywordArg::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
|
@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() {
|
||||||
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
|
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ValuePtr ValueDictionary::operator[](const std::string& key) const {
|
const ValuePtr ValueDictionary::operator[](const std::string &key) const {
|
||||||
auto it = std::find_if(key_values_.begin(), key_values_.end(),
|
auto it = std::find_if(key_values_.begin(), key_values_.end(),
|
||||||
[key](const std::pair<std::string, ValuePtr>& item) { return item.first == key; });
|
[key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
|
||||||
if (it == key_values_.end()) {
|
if (it == key_values_.end()) {
|
||||||
MS_LOG(EXCEPTION) << "The key " << key << " is not in the map";
|
MS_LOG(EXCEPTION) << "The key " << key << " is not in the map";
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValueDictionary::operator==(const Value& other) const {
|
bool ValueDictionary::operator==(const Value &other) const {
|
||||||
if (other.isa<ValueDictionary>()) {
|
if (other.isa<ValueDictionary>()) {
|
||||||
auto other_ = static_cast<const ValueDictionary&>(other);
|
auto other_ = static_cast<const ValueDictionary &>(other);
|
||||||
return *this == other_;
|
return *this == other_;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ValueDictionary::operator==(const ValueDictionary& other) const {
|
bool ValueDictionary::operator==(const ValueDictionary &other) const {
|
||||||
if (key_values_.size() != other.key_values_.size()) {
|
if (key_values_.size() != other.key_values_.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
|
||||||
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
|
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
|
||||||
(void)std::transform(
|
(void)std::transform(
|
||||||
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
||||||
[](const std::pair<std::string, ValuePtr>& item) { return std::make_pair(item.first, item.second->ToAbstract()); });
|
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
|
||||||
return std::make_shared<abstract::AbstractDictionary>(kv);
|
return std::make_shared<abstract::AbstractDictionary>(kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(
|
REGISTER_PYBIND_DEFINE(
|
||||||
RefKey, ([](const py::module* m) {
|
RefKey, ([](const py::module *m) {
|
||||||
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
|
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
|
||||||
}));
|
}));
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,19 +35,19 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ValueSequeue : public Value {
|
class ValueSequeue : public Value {
|
||||||
public:
|
public:
|
||||||
explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) {
|
explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) {
|
||||||
TypePtrList t_list;
|
TypePtrList t_list;
|
||||||
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) {
|
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) {
|
||||||
MS_EXCEPTION_IF_NULL(ele);
|
MS_EXCEPTION_IF_NULL(ele);
|
||||||
return ele->type();
|
return ele->type();
|
||||||
});
|
});
|
||||||
TypePtr t = std::make_shared<Tuple>(t_list);
|
TypePtr t = std::make_shared<Tuple>(t_list);
|
||||||
type_ = t;
|
type_ = t;
|
||||||
}
|
}
|
||||||
ValueSequeue(const std::initializer_list<ValuePtr>& elements) : elements_(elements.begin(), elements.end()) {
|
ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) {
|
||||||
TypePtrList t_list;
|
TypePtrList t_list;
|
||||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
|
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
|
||||||
[](const ValuePtr& ele) { return ele->type(); });
|
[](const ValuePtr &ele) { return ele->type(); });
|
||||||
TypePtr t = std::make_shared<Tuple>(t_list);
|
TypePtr t = std::make_shared<Tuple>(t_list);
|
||||||
type_ = t;
|
type_ = t;
|
||||||
}
|
}
|
||||||
|
@ -56,10 +56,10 @@ class ValueSequeue : public Value {
|
||||||
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
|
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
|
||||||
std::size_t size() const { return elements_.size(); }
|
std::size_t size() const { return elements_.size(); }
|
||||||
bool erase(size_t idx);
|
bool erase(size_t idx);
|
||||||
const ValuePtr operator[](const std::size_t& dim) const;
|
const ValuePtr operator[](const std::size_t &dim) const;
|
||||||
const ValuePtrList& value() const { return elements_; }
|
const ValuePtrList &value() const { return elements_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const ValueSequeue& other) const;
|
bool operator==(const ValueSequeue &other) const;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
|
|
||||||
|
@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr<ValueSequeue>;
|
||||||
|
|
||||||
class ValueTuple : public ValueSequeue {
|
class ValueTuple : public ValueSequeue {
|
||||||
public:
|
public:
|
||||||
explicit ValueTuple(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {}
|
explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
|
||||||
ValueTuple(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {}
|
ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
|
||||||
~ValueTuple() override = default;
|
~ValueTuple() override = default;
|
||||||
MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
|
MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
|
@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr<ValueTuple>;
|
||||||
|
|
||||||
class ValueList : public ValueSequeue {
|
class ValueList : public ValueSequeue {
|
||||||
public:
|
public:
|
||||||
explicit ValueList(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {}
|
explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
|
||||||
ValueList(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {}
|
ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
|
||||||
~ValueList() override = default;
|
~ValueList() override = default;
|
||||||
MS_DECLARE_PARENT(ValueList, ValueSequeue)
|
MS_DECLARE_PARENT(ValueList, ValueSequeue)
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
|
@ -94,7 +94,7 @@ class ValueList : public ValueSequeue {
|
||||||
};
|
};
|
||||||
using ValueListPtr = std::shared_ptr<ValueList>;
|
using ValueListPtr = std::shared_ptr<ValueList>;
|
||||||
|
|
||||||
inline ValuePtr MakeValue(const std::vector<ValuePtr>& v) { return std::make_shared<ValueTuple>(v); }
|
inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); }
|
||||||
inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
|
inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -103,7 +103,7 @@ template <typename T, typename A>
|
||||||
struct is_vector<std::vector<T, A>> : public std::true_type {};
|
struct is_vector<std::vector<T, A>> : public std::true_type {};
|
||||||
|
|
||||||
template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
|
template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
|
||||||
ValuePtr MakeValue(const T& vec) {
|
ValuePtr MakeValue(const T &vec) {
|
||||||
std::vector<ValuePtr> list;
|
std::vector<ValuePtr> list;
|
||||||
(void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
|
(void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
|
||||||
return std::make_shared<ValueTuple>(list);
|
return std::make_shared<ValueTuple>(list);
|
||||||
|
@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) {
|
||||||
|
|
||||||
class ValueSlice : public Value {
|
class ValueSlice : public Value {
|
||||||
public:
|
public:
|
||||||
ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step)
|
ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step)
|
||||||
: start_(start), stop_(stop), step_(step) {}
|
: start_(start), stop_(stop), step_(step) {}
|
||||||
~ValueSlice() override = default;
|
~ValueSlice() override = default;
|
||||||
MS_DECLARE_PARENT(ValueSlice, Value)
|
MS_DECLARE_PARENT(ValueSlice, Value)
|
||||||
std::size_t hash() const override;
|
std::size_t hash() const override;
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const ValueSlice& other) const;
|
bool operator==(const ValueSlice &other) const;
|
||||||
|
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
|
|
||||||
|
@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr<ValueSlice>;
|
||||||
|
|
||||||
class KeywordArg : public Value {
|
class KeywordArg : public Value {
|
||||||
public:
|
public:
|
||||||
KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {}
|
KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {}
|
||||||
~KeywordArg() override = default;
|
~KeywordArg() override = default;
|
||||||
MS_DECLARE_PARENT(KeywordArg, Value)
|
MS_DECLARE_PARENT(KeywordArg, Value)
|
||||||
std::size_t hash() const override;
|
std::size_t hash() const override;
|
||||||
ValuePtr get_value() const { return value_; }
|
ValuePtr get_value() const { return value_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const KeywordArg& other) const;
|
bool operator==(const KeywordArg &other) const;
|
||||||
|
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
|
|
||||||
|
@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr<KeywordArg>;
|
||||||
|
|
||||||
class ValueDictionary : public Value {
|
class ValueDictionary : public Value {
|
||||||
public:
|
public:
|
||||||
explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>>& key_values) : key_values_(key_values) {}
|
explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {}
|
||||||
~ValueDictionary() override = default;
|
~ValueDictionary() override = default;
|
||||||
MS_DECLARE_PARENT(ValueDictionary, Value)
|
MS_DECLARE_PARENT(ValueDictionary, Value)
|
||||||
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
|
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
|
||||||
std::size_t size() const { return key_values_.size(); }
|
std::size_t size() const { return key_values_.size(); }
|
||||||
const ValuePtr operator[](const std::string& key) const;
|
const ValuePtr operator[](const std::string &key) const;
|
||||||
const std::vector<std::pair<std::string, ValuePtr>>& value() const { return key_values_; }
|
const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const ValueDictionary& other) const;
|
bool operator==(const ValueDictionary &other) const;
|
||||||
|
|
||||||
std::string ToString() const override {
|
std::string ToString() const override {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
std::vector<std::string> keys;
|
std::vector<std::string> keys;
|
||||||
std::vector<ValuePtr> values;
|
std::vector<ValuePtr> values;
|
||||||
for (const auto& kv : key_values_) {
|
for (const auto &kv : key_values_) {
|
||||||
keys.push_back(kv.first);
|
keys.push_back(kv.first);
|
||||||
values.push_back(kv.second);
|
values.push_back(kv.second);
|
||||||
}
|
}
|
||||||
buffer << "(Dict: "
|
buffer << "(Dict: "
|
||||||
<< " keys:(";
|
<< " keys:(";
|
||||||
for (const auto& key : keys) {
|
for (const auto &key : keys) {
|
||||||
buffer << key << ", ";
|
buffer << key << ", ";
|
||||||
}
|
}
|
||||||
buffer << ") values:(";
|
buffer << ") values:(";
|
||||||
for (const auto& value : values) {
|
for (const auto &value : values) {
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
buffer << value->DumpText() << ", ";
|
buffer << value->DumpText() << ", ";
|
||||||
}
|
}
|
||||||
|
@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>;
|
||||||
|
|
||||||
class StringImm : public Value {
|
class StringImm : public Value {
|
||||||
public:
|
public:
|
||||||
explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
|
explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
|
||||||
|
|
||||||
~StringImm() override = default;
|
~StringImm() override = default;
|
||||||
MS_DECLARE_PARENT(StringImm, Value)
|
MS_DECLARE_PARENT(StringImm, Value)
|
||||||
std::size_t hash() const override { return hash_; }
|
std::size_t hash() const override { return hash_; }
|
||||||
const std::string& value() const { return str_; }
|
const std::string &value() const { return str_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const StringImm& other) const;
|
bool operator==(const StringImm &other) const;
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
std::string ToString() const override { return str_; }
|
std::string ToString() const override { return str_; }
|
||||||
|
|
||||||
|
@ -218,18 +218,18 @@ class StringImm : public Value {
|
||||||
};
|
};
|
||||||
using StringImmPtr = std::shared_ptr<StringImm>;
|
using StringImmPtr = std::shared_ptr<StringImm>;
|
||||||
IMM_TRAITS(StringImmPtr, std::string)
|
IMM_TRAITS(StringImmPtr, std::string)
|
||||||
IMM_TRAITS(StringImmPtr, const char*)
|
IMM_TRAITS(StringImmPtr, const char *)
|
||||||
|
|
||||||
class RefKey : public Value {
|
class RefKey : public Value {
|
||||||
public:
|
public:
|
||||||
explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
|
explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
|
||||||
|
|
||||||
~RefKey() override = default;
|
~RefKey() override = default;
|
||||||
MS_DECLARE_PARENT(RefKey, Value)
|
MS_DECLARE_PARENT(RefKey, Value)
|
||||||
std::size_t hash() const override { return hash_; }
|
std::size_t hash() const override { return hash_; }
|
||||||
const std::string& tag() const { return tag_; }
|
const std::string &tag() const { return tag_; }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
bool operator==(const RefKey& other) const;
|
bool operator==(const RefKey &other) const;
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
|
std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
|
||||||
|
|
||||||
|
@ -251,13 +251,13 @@ class AnyValue : public Value {
|
||||||
~AnyValue() override = default;
|
~AnyValue() override = default;
|
||||||
MS_DECLARE_PARENT(AnyValue, Value)
|
MS_DECLARE_PARENT(AnyValue, Value)
|
||||||
std::size_t hash() const override { return tid(); }
|
std::size_t hash() const override { return tid(); }
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value &other) const override;
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
};
|
};
|
||||||
extern const ValuePtr kAnyValue;
|
extern const ValuePtr kAnyValue;
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline const char* GetValue(const ValuePtr& value) {
|
inline const char *GetValue(const ValuePtr &value) {
|
||||||
if (value == nullptr) {
|
if (value == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Value is nullptr";
|
MS_LOG(EXCEPTION) << "Value is nullptr";
|
||||||
}
|
}
|
||||||
|
@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) {
|
||||||
|
|
||||||
template <typename T, typename S = typename std::decay<T>::type,
|
template <typename T, typename S = typename std::decay<T>::type,
|
||||||
typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
|
typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
|
||||||
std::vector<U> GetValue(const ValuePtr& value) {
|
std::vector<U> GetValue(const ValuePtr &value) {
|
||||||
if (value == nullptr) {
|
if (value == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Value is nullptr";
|
MS_LOG(EXCEPTION) << "Value is nullptr";
|
||||||
}
|
}
|
||||||
|
@ -280,21 +280,21 @@ std::vector<U> GetValue(const ValuePtr& value) {
|
||||||
<< ">";
|
<< ">";
|
||||||
}
|
}
|
||||||
std::vector<U> rets;
|
std::vector<U> rets;
|
||||||
const std::vector<ValuePtr>& vals = value->cast<ValueSequeuePtr>()->value();
|
const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value();
|
||||||
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
|
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
|
||||||
[](const ValuePtr& v) { return GetValue<U>(v); });
|
[](const ValuePtr &v) { return GetValue<U>(v); });
|
||||||
return rets;
|
return rets;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared<ValueNode>(t); }
|
inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); }
|
||||||
|
|
||||||
template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
|
template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
|
||||||
inline ValueNodePtr NewValueNode(const std::shared_ptr<T>& x) {
|
inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) {
|
||||||
return NewValueNode(MakeValue(x));
|
return NewValueNode(MakeValue(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
|
template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
|
||||||
inline ValueNodePtr NewValueNode(const T& x) {
|
inline ValueNodePtr NewValueNode(const T &x) {
|
||||||
return NewValueNode(MakeValue(x));
|
return NewValueNode(MakeValue(x));
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,15 +22,15 @@
|
||||||
#include "optimizer/opt.h"
|
#include "optimizer/opt.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
using VisitFuncType = std::function<void(const AnfNodePtr&)>;
|
using VisitFuncType = std::function<void(const AnfNodePtr &)>;
|
||||||
class AnfVisitor {
|
class AnfVisitor {
|
||||||
public:
|
public:
|
||||||
virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&);
|
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &);
|
||||||
virtual void Visit(const AnfNodePtr&);
|
virtual void Visit(const AnfNodePtr &);
|
||||||
virtual void Visit(const CNodePtr&);
|
virtual void Visit(const CNodePtr &);
|
||||||
virtual void Visit(const ValueNodePtr&);
|
virtual void Visit(const ValueNodePtr &);
|
||||||
virtual void Visit(const ParameterPtr&);
|
virtual void Visit(const ParameterPtr &);
|
||||||
VisitFuncType Match(const PrimitivePtr&, const std::vector<opt::PredicateFuncType>& = {});
|
VisitFuncType Match(const PrimitivePtr &, const std::vector<opt::PredicateFuncType> & = {});
|
||||||
virtual ~AnfVisitor() = default;
|
virtual ~AnfVisitor() = default;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -26,12 +26,12 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
namespace {
|
namespace {
|
||||||
void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
|
void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
|
||||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) {
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
|
||||||
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
|
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
|
||||||
[&](const std::shared_ptr<kernel::KernelBuildInfo>& kernel_build_info) {
|
[&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
|
||||||
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
|
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
|
||||||
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
|
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
|
||||||
});
|
});
|
||||||
|
@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
void KernelQuery(const CNodePtr& kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) {
|
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||||
TbeMetadataInfo(kernel_node, kernel_info_list);
|
TbeMetadataInfo(kernel_node, kernel_info_list);
|
||||||
|
|
|
@ -38,11 +38,11 @@ class OpAttr {
|
||||||
std::string value() const { return value_; }
|
std::string value() const { return value_; }
|
||||||
std::string default_value() const { return default_value_; }
|
std::string default_value() const { return default_value_; }
|
||||||
|
|
||||||
void set_name(const std::string& name) { name_ = name; }
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
|
void set_param_type(const std::string ¶m_type) { param_type_ = param_type; }
|
||||||
void set_type(const std::string& type) { type_ = type; }
|
void set_type(const std::string &type) { type_ = type; }
|
||||||
void set_value(const std::string& value) { value_ = value; }
|
void set_value(const std::string &value) { value_ = value; }
|
||||||
void set_default_value(const std::string& default_value) { default_value_ = default_value; }
|
void set_default_value(const std::string &default_value) { default_value_ = default_value; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string name_;
|
std::string name_;
|
||||||
|
@ -67,13 +67,13 @@ class OpIOInfo {
|
||||||
std::vector<std::string> formats() const { return formats_; }
|
std::vector<std::string> formats() const { return formats_; }
|
||||||
|
|
||||||
void set_index(const int index) { index_ = index; }
|
void set_index(const int index) { index_ = index; }
|
||||||
void set_name(const std::string& name) { name_ = name; }
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
|
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
|
||||||
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
|
void set_param_type(const std::string ¶m_type) { param_type_ = param_type; }
|
||||||
void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; }
|
void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; }
|
||||||
void set_shape(const std::string& shape) { shape_ = shape; }
|
void set_shape(const std::string &shape) { shape_ = shape; }
|
||||||
void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; }
|
void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
|
||||||
void set_formats(const std::vector<std::string>& formats) { formats_ = formats; }
|
void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int index_ = 0;
|
int index_ = 0;
|
||||||
|
@ -104,24 +104,24 @@ class OpInfo {
|
||||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
|
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
|
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
|
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
|
||||||
const std::unordered_map<size_t, size_t>& ref_infos() const { return ref_infos_; }
|
const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
|
||||||
|
|
||||||
void set_op_name(const std::string& op_name) { op_name_ = op_name; }
|
void set_op_name(const std::string &op_name) { op_name_ = op_name; }
|
||||||
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
|
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
|
||||||
void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; }
|
void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
|
||||||
void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; }
|
void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; }
|
||||||
void set_async_flag(const bool async_flag) { async_flag_ = async_flag; }
|
void set_async_flag(const bool async_flag) { async_flag_ = async_flag; }
|
||||||
void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; }
|
void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; }
|
||||||
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
|
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
|
||||||
void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; }
|
void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
|
||||||
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
|
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
|
||||||
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
|
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
|
||||||
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
|
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
|
||||||
void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); }
|
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
|
||||||
void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); }
|
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
|
||||||
void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); }
|
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
|
||||||
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& inputs) { inputs_ptr_ = inputs; }
|
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; }
|
||||||
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& outputs) { outputs_ptr_ = outputs; }
|
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; }
|
||||||
bool is_ref() const { return !ref_infos_.empty(); }
|
bool is_ref() const { return !ref_infos_.empty(); }
|
||||||
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
|
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
|
||||||
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
|
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
|
||||||
|
|
|
@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) {
|
||||||
return "unknow";
|
return "unknow";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) {
|
bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
|
||||||
bool ret = false;
|
bool ret = false;
|
||||||
try {
|
try {
|
||||||
auto op_json = nlohmann::json::parse(json_string);
|
auto op_json = nlohmann::json::parse(json_string);
|
||||||
|
@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
|
||||||
if (!ret) {
|
if (!ret) {
|
||||||
MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string;
|
MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string;
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(DEBUG) << "get op_json elements failed:" << e.what();
|
MS_LOG(DEBUG) << "get op_json elements failed:" << e.what();
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) {
|
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
|
||||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||||
op_info->set_binfile_name(obj.at(kBinfileName));
|
op_info->set_binfile_name(obj.at(kBinfileName));
|
||||||
op_info->set_compute_cost(obj.at(kComputeCost));
|
op_info->set_compute_cost(obj.at(kComputeCost));
|
||||||
|
@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type,
|
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
|
||||||
const std::string& impl_path) {
|
const std::string &impl_path) {
|
||||||
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
|
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
op_info->set_op_name(obj.at(kOpName));
|
op_info->set_op_name(obj.at(kOpName));
|
||||||
|
@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
||||||
DecodeTBESpecificInfo(obj, op_info);
|
DecodeTBESpecificInfo(obj, op_info);
|
||||||
}
|
}
|
||||||
auto attrs = obj.at(kAttr);
|
auto attrs = obj.at(kAttr);
|
||||||
for (const auto& attr : attrs) {
|
for (const auto &attr : attrs) {
|
||||||
if (!DecodeAttr(attr, imply_type, op_info)) {
|
if (!DecodeAttr(attr, imply_type, op_info)) {
|
||||||
MS_LOG(DEBUG) << "DecodeAttr Failed";
|
MS_LOG(DEBUG) << "DecodeAttr Failed";
|
||||||
return false;
|
return false;
|
||||||
|
@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
||||||
dtype_format = obj.at(kDtypeFormat);
|
dtype_format = obj.at(kDtypeFormat);
|
||||||
}
|
}
|
||||||
auto inputs = obj.at(kIputs);
|
auto inputs = obj.at(kIputs);
|
||||||
for (const auto& input : inputs) {
|
for (const auto &input : inputs) {
|
||||||
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
|
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
|
||||||
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
|
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto outputs = obj.at(kOutputs);
|
auto outputs = obj.at(kOutputs);
|
||||||
for (const auto& output : outputs) {
|
for (const auto &output : outputs) {
|
||||||
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
|
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
|
||||||
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
|
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
|
||||||
return false;
|
return false;
|
||||||
|
@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
|
bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
|
||||||
const std::shared_ptr<OpInfo>& op_info) {
|
const std::shared_ptr<OpInfo> &op_info) {
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
bool ret = true;
|
bool ret = true;
|
||||||
try {
|
try {
|
||||||
|
@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
|
||||||
op_attr->set_default_value(obj.at(kDefaultValue));
|
op_attr->set_default_value(obj.at(kDefaultValue));
|
||||||
}
|
}
|
||||||
op_info->add_attrs_ptr(op_attr);
|
op_info->add_attrs_ptr(op_attr);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what();
|
MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what();
|
||||||
ret = false;
|
ret = false;
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
|
bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
|
||||||
size_t index) {
|
size_t index) {
|
||||||
bool ret = true;
|
bool ret = true;
|
||||||
try {
|
try {
|
||||||
std::vector<std::string> dtype;
|
std::vector<std::string> dtype;
|
||||||
std::vector<std::string> format;
|
std::vector<std::string> format;
|
||||||
for (const auto& it : dtype_format) {
|
for (const auto &it : dtype_format) {
|
||||||
dtype.emplace_back(it[index][0]);
|
dtype.emplace_back(it[index][0]);
|
||||||
format.emplace_back(it[index][1]);
|
format.emplace_back(it[index][1]);
|
||||||
}
|
}
|
||||||
op_io->set_dtypes(dtype);
|
op_io->set_dtypes(dtype);
|
||||||
op_io->set_formats(format);
|
op_io->set_formats(format);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
|
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
|
||||||
ret = false;
|
ret = false;
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
|
bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
|
||||||
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) {
|
const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
|
||||||
bool ret = true;
|
bool ret = true;
|
||||||
try {
|
try {
|
||||||
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
|
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
|
||||||
|
@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
|
||||||
} else if (io_type == kOutput) {
|
} else if (io_type == kOutput) {
|
||||||
op_info->add_outputs_ptr(op_io);
|
op_info->add_outputs_ptr(op_io);
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what();
|
MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what();
|
||||||
ret = false;
|
ret = false;
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) {
|
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
|
||||||
auto context = MsContext::GetInstance();
|
auto context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
bool is_gpu = (context->device_target() == kGPUDevice);
|
bool is_gpu = (context->device_target() == kGPUDevice);
|
||||||
|
@ -260,7 +260,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
|
||||||
<< ", current op num:" << op_info_.size();
|
<< ", current op num:" << op_info_.size();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
for (const auto& op_info : op_info_) {
|
for (const auto &op_info : op_info_) {
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
|
if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
|
||||||
return op_info;
|
return op_info;
|
||||||
|
@ -271,14 +271,14 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) {
|
bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
const auto& output_infos = op_info->outputs_ptr();
|
const auto &output_infos = op_info->outputs_ptr();
|
||||||
const auto& input_infos = op_info->inputs_ptr();
|
const auto &input_infos = op_info->inputs_ptr();
|
||||||
for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
|
for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
|
||||||
const auto& out_name = output_infos[out_index]->name();
|
const auto &out_name = output_infos[out_index]->name();
|
||||||
for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
|
for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
|
||||||
const auto& in_name = input_infos[in_index]->name();
|
const auto &in_name = input_infos[in_index]->name();
|
||||||
if (out_name == in_name) {
|
if (out_name == in_name) {
|
||||||
if (op_info->has_ref_index(out_index)) {
|
if (op_info->has_ref_index(out_index)) {
|
||||||
MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info";
|
MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info";
|
||||||
|
@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo>& op_info) {
|
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
for (const auto& exist_op_info : op_info_) {
|
for (const auto &exist_op_info : op_info_) {
|
||||||
MS_EXCEPTION_IF_NULL(exist_op_info);
|
MS_EXCEPTION_IF_NULL(exist_op_info);
|
||||||
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
|
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
|
||||||
exist_op_info->impl_path() != op_info->impl_path()) {
|
exist_op_info->impl_path() != op_info->impl_path()) {
|
||||||
|
|
|
@ -28,23 +28,23 @@ class OpLib {
|
||||||
public:
|
public:
|
||||||
OpLib() = default;
|
OpLib() = default;
|
||||||
virtual ~OpLib() = default;
|
virtual ~OpLib() = default;
|
||||||
bool RegOp(const std::string& json_string, const std::string& impl_path);
|
bool RegOp(const std::string &json_string, const std::string &impl_path);
|
||||||
static std::shared_ptr<OpInfo> FindOp(const std::string& op_name, OpImplyType imply_type);
|
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
static std::vector<std::shared_ptr<OpInfo>> op_info_;
|
static std::vector<std::shared_ptr<OpInfo>> op_info_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path);
|
static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path);
|
||||||
static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
|
static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
|
||||||
const std::shared_ptr<OpInfo>& op_info);
|
const std::shared_ptr<OpInfo> &op_info);
|
||||||
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
|
static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
|
||||||
size_t index);
|
size_t index);
|
||||||
static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info);
|
static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info);
|
||||||
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
|
static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
|
||||||
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
|
const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format);
|
||||||
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);
|
static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info);
|
||||||
static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info);
|
static bool CheckRepetition(const std::shared_ptr<OpInfo> &op_info);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,6 +19,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
// cppcheck-suppress unusedFunction
|
// cppcheck-suppress unusedFunction
|
||||||
std::string set_version(const std::string& version) { return version; }
|
std::string set_version(const std::string &version) { return version; }
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -42,11 +42,11 @@ struct OpMergedInfo {
|
||||||
};
|
};
|
||||||
|
|
||||||
using GenAttrFuncType =
|
using GenAttrFuncType =
|
||||||
std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto*, const PrimitivePtr&)>;
|
std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>;
|
||||||
|
|
||||||
template <typename T, size_t rep_cnt = 0>
|
template <typename T, size_t rep_cnt = 0>
|
||||||
void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type,
|
void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
|
||||||
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
|
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
|
||||||
auto casted_value = dyn_cast<T>(value);
|
auto casted_value = dyn_cast<T>(value);
|
||||||
if (casted_value == nullptr) {
|
if (casted_value == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
|
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
|
||||||
|
@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t beg_idx = 0>
|
template <size_t beg_idx = 0>
|
||||||
void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type,
|
void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
|
||||||
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
|
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
|
||||||
auto tuple_ptr = dyn_cast<ValueTuple>(value);
|
auto tuple_ptr = dyn_cast<ValueTuple>(value);
|
||||||
if (tuple_ptr == nullptr) {
|
if (tuple_ptr == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed.";
|
MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed.";
|
||||||
|
@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib
|
||||||
attr_proto->set_type(attr_type);
|
attr_proto->set_type(attr_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType,
|
void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
|
||||||
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
|
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||||
auto attr_value = GetValue<std::string>(value);
|
auto attr_value = GetValue<std::string>(value);
|
||||||
if (attr_value == "VALID") {
|
if (attr_value == "VALID") {
|
||||||
|
@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType
|
||||||
|
|
||||||
class OpAttrInfo {
|
class OpAttrInfo {
|
||||||
public:
|
public:
|
||||||
OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name,
|
OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name,
|
||||||
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr)
|
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr)
|
||||||
: attr_name_(attr_name),
|
: attr_name_(attr_name),
|
||||||
onnx_attr_name_(onnx_attr_name),
|
onnx_attr_name_(onnx_attr_name),
|
||||||
onnx_attr_type_(onnx_attr_type),
|
onnx_attr_type_(onnx_attr_type),
|
||||||
fn_gen_attr_(fn_gen_attr) {}
|
fn_gen_attr_(fn_gen_attr) {}
|
||||||
~OpAttrInfo() {}
|
~OpAttrInfo() {}
|
||||||
|
|
||||||
const std::string& attr_name() const { return attr_name_; }
|
const std::string &attr_name() const { return attr_name_; }
|
||||||
const std::string& onnx_attr_name() const { return onnx_attr_name_; }
|
const std::string &onnx_attr_name() const { return onnx_attr_name_; }
|
||||||
onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; }
|
onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; }
|
||||||
GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; }
|
GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; }
|
||||||
|
|
||||||
|
@ -134,27 +134,27 @@ class OpAttrInfo {
|
||||||
|
|
||||||
class OpNameInfo {
|
class OpNameInfo {
|
||||||
public:
|
public:
|
||||||
OpNameInfo& set_op_type(const std::string& op_type) {
|
OpNameInfo &set_op_type(const std::string &op_type) {
|
||||||
op_type_ = op_type;
|
op_type_ = op_type;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& op_type() const { return op_type_; }
|
const std::string &op_type() const { return op_type_; }
|
||||||
|
|
||||||
OpNameInfo& set_onnx_type(const std::string& onnx_type) {
|
OpNameInfo &set_onnx_type(const std::string &onnx_type) {
|
||||||
onnx_type_ = onnx_type;
|
onnx_type_ = onnx_type;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& onnx_type() const { return onnx_type_; }
|
const std::string &onnx_type() const { return onnx_type_; }
|
||||||
|
|
||||||
OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name,
|
OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name,
|
||||||
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) {
|
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) {
|
||||||
op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr));
|
op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<OpAttrInfo>& op_attrs() const { return op_attrs_; }
|
const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string op_type_; // operator type of MindSpore
|
std::string op_type_; // operator type of MindSpore
|
||||||
|
@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
||||||
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>)
|
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>)
|
||||||
.Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
|
.Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
|
||||||
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING,
|
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING,
|
||||||
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto,
|
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
|
||||||
const PrimitivePtr& prim) {
|
const PrimitivePtr &prim) {
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||||
auto attr_value = GetValue<std::string>(value);
|
auto attr_value = GetValue<std::string>(value);
|
||||||
if (attr_value == "valid") {
|
if (attr_value == "valid") {
|
||||||
|
@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax,
|
||||||
SetAttrValueToProto<Int32Imm>)
|
SetAttrValueToProto<Int32Imm>)
|
||||||
.Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT,
|
.Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT,
|
||||||
[](ValuePtr, onnx::AttributeProto_AttributeType,
|
[](ValuePtr, onnx::AttributeProto_AttributeType,
|
||||||
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
|
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||||
attr_proto->set_i(0);
|
attr_proto->set_i(0);
|
||||||
}))
|
}))
|
||||||
|
@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
||||||
|
|
||||||
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
||||||
|
|
||||||
void RegisterOpConverters(const std::function<void(OpNameInfo&&)>& fn) {
|
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)());
|
fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)());
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(Mul)());
|
fn(OP_CONVERT_FUNCTION_NAME(Mul)());
|
||||||
|
|
||||||
|
@ -265,16 +265,16 @@ class OpConvertRegistry {
|
||||||
public:
|
public:
|
||||||
~OpConvertRegistry() { Clear(); }
|
~OpConvertRegistry() { Clear(); }
|
||||||
|
|
||||||
static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }
|
static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }
|
||||||
|
|
||||||
static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); }
|
static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); }
|
||||||
|
|
||||||
static OpConvertRegistry& GetSingleton() {
|
static OpConvertRegistry &GetSingleton() {
|
||||||
static OpConvertRegistry registry = OpConvertRegistry();
|
static OpConvertRegistry registry = OpConvertRegistry();
|
||||||
return registry;
|
return registry;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const std::unordered_map<std::string, OpNameInfo>& GetOpConvertMap() { return GetSingleton().op_map_; }
|
static const std::unordered_map<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; }
|
||||||
|
|
||||||
void Clear() noexcept { op_map_.clear(); }
|
void Clear() noexcept { op_map_.clear(); }
|
||||||
|
|
||||||
|
@ -289,59 +289,59 @@ class OnnxExporter {
|
||||||
OnnxExporter() {}
|
OnnxExporter() {}
|
||||||
~OnnxExporter() {}
|
~OnnxExporter() {}
|
||||||
|
|
||||||
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
|
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitModelInfo();
|
void InitModelInfo();
|
||||||
|
|
||||||
void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto);
|
void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto);
|
||||||
void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto);
|
void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs,
|
const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
|
static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
|
||||||
void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false);
|
void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false);
|
||||||
void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto);
|
void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto);
|
||||||
|
|
||||||
void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes,
|
void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
|
||||||
std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr);
|
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr);
|
||||||
void ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
std::string GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* const graph_proto);
|
onnx::GraphProto *const graph_proto);
|
||||||
|
|
||||||
void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto);
|
void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto);
|
||||||
void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto);
|
void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto);
|
||||||
|
|
||||||
size_t AllocateNodeIndex() { return ++onnx_node_index_; }
|
size_t AllocateNodeIndex() { return ++onnx_node_index_; }
|
||||||
|
|
||||||
void ResetNodeIndex() { onnx_node_index_ = 0; }
|
void ResetNodeIndex() { onnx_node_index_ = 0; }
|
||||||
|
|
||||||
static int GetInt32Value(const AnfNodePtr& node) {
|
static int GetInt32Value(const AnfNodePtr &node) {
|
||||||
auto value_node_ptr = dyn_cast<ValueNode>(node);
|
auto value_node_ptr = dyn_cast<ValueNode>(node);
|
||||||
MS_EXCEPTION_IF_NULL(value_node_ptr);
|
MS_EXCEPTION_IF_NULL(value_node_ptr);
|
||||||
return GetValue<int>(value_node_ptr->value());
|
return GetValue<int>(value_node_ptr->value());
|
||||||
|
@ -352,7 +352,7 @@ class OnnxExporter {
|
||||||
size_t onnx_node_index_ = 0;
|
size_t onnx_node_index_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) {
|
std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) {
|
||||||
OpConvertRegistry::GetSingleton().Clear();
|
OpConvertRegistry::GetSingleton().Clear();
|
||||||
OpConvertRegistry::RegisterAllOpConverters();
|
OpConvertRegistry::RegisterAllOpConverters();
|
||||||
InitModelInfo();
|
InitModelInfo();
|
||||||
onnx::GraphProto* graph_proto = model_.mutable_graph();
|
onnx::GraphProto *graph_proto = model_.mutable_graph();
|
||||||
ExportFuncGraph(func_graph, graph_proto);
|
ExportFuncGraph(func_graph, graph_proto);
|
||||||
return model_.SerializeAsString();
|
return model_.SerializeAsString();
|
||||||
}
|
}
|
||||||
|
@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() {
|
||||||
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
|
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
|
||||||
model_.set_producer_name("MindSpore");
|
model_.set_producer_name("MindSpore");
|
||||||
model_.set_producer_version("1.0");
|
model_.set_producer_version("1.0");
|
||||||
onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import();
|
onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import();
|
||||||
opset_proto->set_version(9);
|
opset_proto->set_version(9);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) {
|
void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
||||||
std::map<AnfNodePtr, size_t> node_map;
|
std::map<AnfNodePtr, size_t> node_map;
|
||||||
|
|
||||||
onnx_node_index_ = func_graph->parameters().size();
|
onnx_node_index_ = func_graph->parameters().size();
|
||||||
|
@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr
|
||||||
ExportNodes(func_graph, &node_map, graph_proto);
|
ExportNodes(func_graph, &node_map, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) {
|
void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
||||||
for (auto& param : func_graph->parameters()) {
|
for (auto ¶m : func_graph->parameters()) {
|
||||||
const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
|
const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
|
||||||
if (param_ptr == nullptr) {
|
if (param_ptr == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
|
MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
|
||||||
}
|
}
|
||||||
|
|
||||||
onnx::ValueInfoProto* input_proto = graph_proto->add_input();
|
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
|
||||||
input_proto->set_name(param_ptr->ToString());
|
input_proto->set_name(param_ptr->ToString());
|
||||||
SetValueInfoType(param_ptr, input_proto);
|
SetValueInfoType(param_ptr, input_proto);
|
||||||
|
|
||||||
|
@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// parameter with default value is an ONNX initializer
|
// parameter with default value is an ONNX initializer
|
||||||
onnx::TensorProto* initializer_proto = graph_proto->add_initializer();
|
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
|
||||||
initializer_proto->set_name(param_ptr->ToString());
|
initializer_proto->set_name(param_ptr->ToString());
|
||||||
SetTensorProtoInfo(param_ptr, initializer_proto);
|
SetTensorProtoInfo(param_ptr, initializer_proto);
|
||||||
// set value for initializer
|
// set value for initializer
|
||||||
|
@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) {
|
void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) {
|
||||||
auto dtype = node->Type();
|
auto dtype = node->Type();
|
||||||
auto shape = node->Shape();
|
auto shape = node->Shape();
|
||||||
onnx::TypeProto* type_proto = value_proto->mutable_type();
|
onnx::TypeProto *type_proto = value_proto->mutable_type();
|
||||||
if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
||||||
auto tensor = dyn_cast<TensorType>(dtype);
|
auto tensor = dyn_cast<TensorType>(dtype);
|
||||||
auto elem_type = tensor->element();
|
auto elem_type = tensor->element();
|
||||||
const auto& dims = dyn_cast<abstract::Shape>(shape)->shape();
|
const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
|
||||||
// output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64
|
// output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64
|
||||||
auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id());
|
auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id());
|
||||||
type_proto->mutable_tensor_type()->set_elem_type(type);
|
type_proto->mutable_tensor_type()->set_elem_type(type);
|
||||||
|
|
||||||
for (const auto& dim : dims) {
|
for (const auto &dim : dims) {
|
||||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) {
|
void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) {
|
||||||
auto dtype = param->Type();
|
auto dtype = param->Type();
|
||||||
auto shape = param->Shape();
|
auto shape = param->Shape();
|
||||||
if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
||||||
|
@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro
|
||||||
|
|
||||||
auto tensor = dyn_cast<TensorType>(dtype);
|
auto tensor = dyn_cast<TensorType>(dtype);
|
||||||
auto elem_type = tensor->element();
|
auto elem_type = tensor->element();
|
||||||
const auto& dims = dyn_cast<abstract::Shape>(shape)->shape();
|
const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
|
||||||
tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id()));
|
tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id()));
|
||||||
for (const auto& dim : dims) {
|
for (const auto &dim : dims) {
|
||||||
tensor_proto->add_dims(dim);
|
tensor_proto->add_dims(dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes,
|
void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
|
||||||
std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr) {
|
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
|
||||||
std::unordered_map<AnfNodePtr, OpMergedInfo>& op_merged_infos = *op_merged_infos_ptr;
|
std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr;
|
||||||
|
|
||||||
for (auto& node : nodes) {
|
for (auto &node : nodes) {
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto
|
||||||
// if the key `input` does not exist, just create a new one
|
// if the key `input` does not exist, just create a new one
|
||||||
op_merged_infos[cnode].referred_count += 1;
|
op_merged_infos[cnode].referred_count += 1;
|
||||||
}
|
}
|
||||||
for (auto& input : cnode->inputs()) {
|
for (auto &input : cnode->inputs()) {
|
||||||
if (!input->isa<CNode>()) {
|
if (!input->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto
|
||||||
* | +-- Parameter
|
* | +-- Parameter
|
||||||
* | `-- ValueNode
|
* | `-- ValueNode
|
||||||
*/
|
*/
|
||||||
void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* const graph_proto) {
|
onnx::GraphProto *const graph_proto) {
|
||||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||||
|
|
||||||
std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
|
std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
|
||||||
MatchAndMark(func_graph, nodes, &op_merged_infos);
|
MatchAndMark(func_graph, nodes, &op_merged_infos);
|
||||||
|
|
||||||
for (const AnfNodePtr& node : nodes) {
|
for (const AnfNodePtr &node : nodes) {
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodeP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
|
void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
auto input_shape = node->input(2);
|
auto input_shape = node->input(2);
|
||||||
std::string name_shape;
|
std::string name_shape;
|
||||||
if (input_shape->isa<ValueNode>()) {
|
if (input_shape->isa<ValueNode>()) {
|
||||||
auto const_node_idx = AllocateNodeIndex();
|
auto const_node_idx = AllocateNodeIndex();
|
||||||
(*node_map_ptr)[input_shape] = const_node_idx;
|
(*node_map_ptr)[input_shape] = const_node_idx;
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
name_shape = std::to_string(const_node_idx);
|
name_shape = std::to_string(const_node_idx);
|
||||||
node_proto->add_output(name_shape);
|
node_proto->add_output(name_shape);
|
||||||
|
|
||||||
node_proto->set_op_type("Constant");
|
node_proto->set_op_type("Constant");
|
||||||
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_name("value");
|
attr_proto->set_name("value");
|
||||||
|
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||||
|
@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C
|
||||||
|
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
(*node_map_ptr)[node] = node_idx;
|
(*node_map_ptr)[node] = node_idx;
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->set_op_type(prim::kPrimReshape->name());
|
node_proto->set_op_type(prim::kPrimReshape->name());
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
node_proto->add_input(name_x);
|
node_proto->add_input(name_x);
|
||||||
node_proto->add_input(name_shape);
|
node_proto->add_input(name_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
|
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr,
|
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* const graph_proto) {
|
onnx::GraphProto *const graph_proto) {
|
||||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
auto input_axis = node->input(2);
|
auto input_axis = node->input(2);
|
||||||
|
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
(*node_map_ptr)[node] = node_idx;
|
(*node_map_ptr)[node] = node_idx;
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->set_op_type(prim::kPrimReduceMean->name());
|
node_proto->set_op_type(prim::kPrimReduceMean->name());
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
node_proto->add_input(input_data);
|
node_proto->add_input(input_data);
|
||||||
|
|
||||||
if (input_axis->isa<ValueNode>()) {
|
if (input_axis->isa<ValueNode>()) {
|
||||||
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_name("axes");
|
attr_proto->set_name("axes");
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
||||||
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
|
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
|
||||||
|
@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
|
void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
auto input_type = node->input(2);
|
auto input_type = node->input(2);
|
||||||
|
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
(*node_map_ptr)[node] = node_idx;
|
(*node_map_ptr)[node] = node_idx;
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->set_op_type(prim::kPrimCast->name());
|
node_proto->set_op_type(prim::kPrimCast->name());
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
node_proto->add_input(input_data);
|
node_proto->add_input(input_data);
|
||||||
|
|
||||||
if (input_type->isa<ValueNode>()) {
|
if (input_type->isa<ValueNode>()) {
|
||||||
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_name("to");
|
attr_proto->set_name("to");
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||||
auto type_value = dyn_cast<ValueNode>(input_type)->value();
|
auto type_value = dyn_cast<ValueNode>(input_type)->value();
|
||||||
|
@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
|
void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
|
auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
|
||||||
|
|
||||||
|
@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo
|
||||||
// format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2]
|
// format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2]
|
||||||
if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) {
|
if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) {
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->set_op_type("Unsqueeze");
|
node_proto->set_op_type("Unsqueeze");
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
|
|
||||||
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
||||||
attr_proto->set_name("axes");
|
attr_proto->set_name("axes");
|
||||||
attr_proto->add_ints(1);
|
attr_proto->add_ints(1);
|
||||||
|
@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo
|
||||||
|
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
(*node_map_ptr)[node] = node_idx;
|
(*node_map_ptr)[node] = node_idx;
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->set_op_type("PRelu");
|
node_proto->set_op_type("PRelu");
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
node_proto->add_input(input_x);
|
node_proto->add_input(input_x);
|
||||||
node_proto->add_input(input_slope);
|
node_proto->add_input(input_slope);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
||||||
if (node->IsApply(prim::kPrimReshape)) {
|
if (node->IsApply(prim::kPrimReshape)) {
|
||||||
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
|
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
|
||||||
|
@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n
|
||||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs,
|
const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
|
||||||
onnx::GraphProto* const graph_proto) {
|
onnx::GraphProto *const graph_proto) {
|
||||||
auto op_map = OpConvertRegistry::GetOpConvertMap();
|
auto op_map = OpConvertRegistry::GetOpConvertMap();
|
||||||
auto op_iter = op_map.find(prim->name());
|
auto op_iter = op_map.find(prim->name());
|
||||||
if (op_iter == op_map.end()) {
|
if (op_iter == op_map.end()) {
|
||||||
MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map";
|
MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map";
|
||||||
}
|
}
|
||||||
const OpNameInfo& op_convert_info = op_iter->second;
|
const OpNameInfo &op_convert_info = op_iter->second;
|
||||||
|
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
|
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
node_proto->set_op_type(op_convert_info.onnx_type());
|
node_proto->set_op_type(op_convert_info.onnx_type());
|
||||||
|
|
||||||
// Set inputs
|
// Set inputs
|
||||||
for (const auto& input : inputs) {
|
for (const auto &input : inputs) {
|
||||||
auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
|
auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
|
||||||
node_proto->add_input(input_name);
|
node_proto->add_input(input_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set node attribute
|
// Set node attribute
|
||||||
for (const OpAttrInfo& attr : op_convert_info.op_attrs()) {
|
for (const OpAttrInfo &attr : op_convert_info.op_attrs()) {
|
||||||
const std::string& attr_name = attr.attr_name();
|
const std::string &attr_name = attr.attr_name();
|
||||||
ValuePtr attr_value = nullptr;
|
ValuePtr attr_value = nullptr;
|
||||||
if (!attr_name.empty()) {
|
if (!attr_name.empty()) {
|
||||||
attr_value = prim->GetAttr(attr_name);
|
attr_value = prim->GetAttr(attr_name);
|
||||||
|
@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma
|
||||||
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name;
|
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
|
||||||
onnx_attr_proto->set_name(attr.onnx_attr_name());
|
onnx_attr_proto->set_name(attr.onnx_attr_name());
|
||||||
attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim);
|
attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim);
|
||||||
}
|
}
|
||||||
return node_idx;
|
return node_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
auto conv_node = dyn_cast<CNode>(node->input(1));
|
auto conv_node = dyn_cast<CNode>(node->input(1));
|
||||||
auto input_x = conv_node->input(1); // conv input x
|
auto input_x = conv_node->input(1); // conv input x
|
||||||
auto input_w = conv_node->input(2); // conv weight(filter)
|
auto input_w = conv_node->input(2); // conv weight(filter)
|
||||||
|
@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt
|
||||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto);
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
auto matmul_node = dyn_cast<CNode>(node->input(1));
|
auto matmul_node = dyn_cast<CNode>(node->input(1));
|
||||||
auto input_x = matmul_node->input(1); // matmul input x
|
auto input_x = matmul_node->input(1); // matmul input x
|
||||||
auto input_y = matmul_node->input(2); // matmul input y
|
auto input_y = matmul_node->input(2); // matmul input y
|
||||||
|
@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt
|
||||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node,
|
void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr,
|
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* const graph_proto) {
|
onnx::GraphProto *const graph_proto) {
|
||||||
auto batch_norm_node = dyn_cast<CNode>(node->input(1));
|
auto batch_norm_node = dyn_cast<CNode>(node->input(1));
|
||||||
|
|
||||||
PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value());
|
PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value());
|
||||||
|
@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN
|
||||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
|
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
if (node->inputs().size() != 2) {
|
if (node->inputs().size() != 2) {
|
||||||
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
||||||
}
|
}
|
||||||
AnfNodePtr arg = node->input(1);
|
AnfNodePtr arg = node->input(1);
|
||||||
std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto);
|
std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto);
|
||||||
onnx::ValueInfoProto* output_proto = graph_proto->add_output();
|
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||||
output_proto->set_name(name);
|
output_proto->set_name(name);
|
||||||
SetValueInfoType(arg, output_proto, false);
|
SetValueInfoType(arg, output_proto, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
|
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto* const graph_proto) {
|
onnx::GraphProto *const graph_proto) {
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
auto iter = node_map_ptr->find(node);
|
auto iter = node_map_ptr->find(node);
|
||||||
if (iter == node_map_ptr->end()) {
|
if (iter == node_map_ptr->end()) {
|
||||||
|
@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN
|
||||||
(*node_map_ptr)[node] = node_idx;
|
(*node_map_ptr)[node] = node_idx;
|
||||||
std::string node_name = std::to_string(node_idx);
|
std::string node_name = std::to_string(node_idx);
|
||||||
|
|
||||||
onnx::NodeProto* node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->add_output(node_name);
|
node_proto->add_output(node_name);
|
||||||
|
|
||||||
SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto);
|
SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto);
|
||||||
|
@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN
|
||||||
MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name();
|
MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) {
|
void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
|
||||||
auto tuple_ptr = dyn_cast<ValueTuple>(value);
|
auto tuple_ptr = dyn_cast<ValueTuple>(value);
|
||||||
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
||||||
if (tuple_ptr->size() == 0) {
|
if (tuple_ptr->size() == 0) {
|
||||||
|
@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) {
|
void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) {
|
||||||
node_proto->set_op_type("Constant");
|
node_proto->set_op_type("Constant");
|
||||||
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_name("value");
|
attr_proto->set_name("value");
|
||||||
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
|
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) {
|
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
|
||||||
OnnxExporter exporter;
|
OnnxExporter exporter;
|
||||||
return exporter.GetOnnxProtoString(func_graph);
|
return exporter.GetOnnxProtoString(func_graph);
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown };
|
||||||
|
|
||||||
// Whether has a T type data in AnyPtrList.
|
// Whether has a T type data in AnyPtrList.
|
||||||
template <class T>
|
template <class T>
|
||||||
bool HasType(const AnyPtrList& list) {
|
bool HasType(const AnyPtrList &list) {
|
||||||
bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is<T>(); });
|
bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
DataType InferType(const AnyPtrList& list) {
|
DataType InferType(const AnyPtrList &list) {
|
||||||
if (HasType<double>(list)) {
|
if (HasType<double>(list)) {
|
||||||
return DataType::kDouble;
|
return DataType::kDouble;
|
||||||
} else if (HasType<float>(list)) {
|
} else if (HasType<float>(list)) {
|
||||||
|
@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#define SCALAR_OP(op_t) \
|
#define SCALAR_OP(op_t) \
|
||||||
ValuePtr Scalar##op_t(const ValuePtrList& list) { \
|
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
|
||||||
do { \
|
do { \
|
||||||
if (list.size() < 2) { \
|
if (list.size() < 2) { \
|
||||||
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
|
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
|
||||||
|
@ -223,7 +223,7 @@ SCALAR_OP(Pow)
|
||||||
SCALAR_OP(Floordiv)
|
SCALAR_OP(Floordiv)
|
||||||
|
|
||||||
#define LOGIC_OP(op_t) \
|
#define LOGIC_OP(op_t) \
|
||||||
ValuePtr Scalar##op_t(const ValuePtrList& list) { \
|
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
|
||||||
if (list.size() < 2) { \
|
if (list.size() < 2) { \
|
||||||
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
|
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
|
||||||
} \
|
} \
|
||||||
|
@ -274,7 +274,7 @@ LOGIC_OP(Ne)
|
||||||
LOGIC_OP(Le)
|
LOGIC_OP(Le)
|
||||||
LOGIC_OP(Ge)
|
LOGIC_OP(Ge)
|
||||||
|
|
||||||
ValuePtr ScalarUAdd(const ValuePtrList& list) {
|
ValuePtr ScalarUAdd(const ValuePtrList &list) {
|
||||||
if (list.size() != 1) {
|
if (list.size() != 1) {
|
||||||
MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
|
MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
|
||||||
}
|
}
|
||||||
|
@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr ScalarUSub(const ValuePtrList& list) {
|
ValuePtr ScalarUSub(const ValuePtrList &list) {
|
||||||
if (list.size() != 1) {
|
if (list.size() != 1) {
|
||||||
MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
|
MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
|
||||||
}
|
}
|
||||||
|
@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) {
|
||||||
MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
|
MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr ScalarLog(const ValuePtrList& list) {
|
ValuePtr ScalarLog(const ValuePtrList &list) {
|
||||||
if (list.empty()) {
|
if (list.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
|
MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
|
||||||
}
|
}
|
||||||
|
@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) {
|
||||||
MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
|
MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr BoolNot(const ValuePtrList& list) {
|
ValuePtr BoolNot(const ValuePtrList &list) {
|
||||||
if (list.empty()) {
|
if (list.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
|
MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
|
||||||
}
|
}
|
||||||
|
@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) {
|
||||||
MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
|
MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr BoolAnd(const ValuePtrList& list) {
|
ValuePtr BoolAnd(const ValuePtrList &list) {
|
||||||
if (list.size() < 2) {
|
if (list.size() < 2) {
|
||||||
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
|
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
|
||||||
}
|
}
|
||||||
|
@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) {
|
||||||
MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
|
MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr BoolOr(const ValuePtrList& list) {
|
ValuePtr BoolOr(const ValuePtrList &list) {
|
||||||
if (list.size() < 2) {
|
if (list.size() < 2) {
|
||||||
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
|
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
|
||||||
}
|
}
|
||||||
|
@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) {
|
||||||
MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
|
MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr BoolEq(const ValuePtrList& list) {
|
ValuePtr BoolEq(const ValuePtrList &list) {
|
||||||
if (list.size() < 2) {
|
if (list.size() < 2) {
|
||||||
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
|
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,29 +29,29 @@ namespace prim {
|
||||||
using Any = mindspore::Any;
|
using Any = mindspore::Any;
|
||||||
using AnyPtrList = std::vector<std::shared_ptr<Any>>;
|
using AnyPtrList = std::vector<std::shared_ptr<Any>>;
|
||||||
using ValuePtrList = std::vector<ValuePtr>;
|
using ValuePtrList = std::vector<ValuePtr>;
|
||||||
using OpsFunction = std::function<Any(const AnyPtrList&)>;
|
using OpsFunction = std::function<Any(const AnyPtrList &)>;
|
||||||
using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr>&)>;
|
using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr> &)>;
|
||||||
|
|
||||||
ValuePtr ScalarAdd(const ValuePtrList& list);
|
ValuePtr ScalarAdd(const ValuePtrList &list);
|
||||||
ValuePtr ScalarSub(const ValuePtrList& list);
|
ValuePtr ScalarSub(const ValuePtrList &list);
|
||||||
ValuePtr ScalarMul(const ValuePtrList& list);
|
ValuePtr ScalarMul(const ValuePtrList &list);
|
||||||
ValuePtr ScalarDiv(const ValuePtrList& list);
|
ValuePtr ScalarDiv(const ValuePtrList &list);
|
||||||
ValuePtr ScalarMod(const ValuePtrList& list);
|
ValuePtr ScalarMod(const ValuePtrList &list);
|
||||||
ValuePtr ScalarPow(const ValuePtrList& list);
|
ValuePtr ScalarPow(const ValuePtrList &list);
|
||||||
ValuePtr ScalarFloordiv(const ValuePtrList& list);
|
ValuePtr ScalarFloordiv(const ValuePtrList &list);
|
||||||
ValuePtr ScalarUAdd(const ValuePtrList& list);
|
ValuePtr ScalarUAdd(const ValuePtrList &list);
|
||||||
ValuePtr ScalarUSub(const ValuePtrList& list);
|
ValuePtr ScalarUSub(const ValuePtrList &list);
|
||||||
ValuePtr ScalarLog(const ValuePtrList& list);
|
ValuePtr ScalarLog(const ValuePtrList &list);
|
||||||
ValuePtr ScalarEq(const ValuePtrList& list);
|
ValuePtr ScalarEq(const ValuePtrList &list);
|
||||||
ValuePtr ScalarLt(const ValuePtrList& list);
|
ValuePtr ScalarLt(const ValuePtrList &list);
|
||||||
ValuePtr ScalarGt(const ValuePtrList& list);
|
ValuePtr ScalarGt(const ValuePtrList &list);
|
||||||
ValuePtr ScalarNe(const ValuePtrList& list);
|
ValuePtr ScalarNe(const ValuePtrList &list);
|
||||||
ValuePtr ScalarLe(const ValuePtrList& list);
|
ValuePtr ScalarLe(const ValuePtrList &list);
|
||||||
ValuePtr ScalarGe(const ValuePtrList& list);
|
ValuePtr ScalarGe(const ValuePtrList &list);
|
||||||
ValuePtr BoolNot(const ValuePtrList& list);
|
ValuePtr BoolNot(const ValuePtrList &list);
|
||||||
ValuePtr BoolAnd(const ValuePtrList& list);
|
ValuePtr BoolAnd(const ValuePtrList &list);
|
||||||
ValuePtr BoolOr(const ValuePtrList& list);
|
ValuePtr BoolOr(const ValuePtrList &list);
|
||||||
ValuePtr BoolEq(const ValuePtrList& list);
|
ValuePtr BoolEq(const ValuePtrList &list);
|
||||||
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
|
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
|
||||||
// Apply a function of two arguments cumulatively to the items of a sequence,
|
// Apply a function of two arguments cumulatively to the items of a sequence,
|
||||||
// from left to right, so as to reduce the sequence to a single value.For example,
|
// from left to right, so as to reduce the sequence to a single value.For example,
|
||||||
// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
|
// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
|
||||||
AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) {
|
AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
|
||||||
std::shared_ptr<Any> ret;
|
std::shared_ptr<Any> ret;
|
||||||
size_t size = list.size();
|
size_t size = list.size();
|
||||||
if (size < 2) {
|
if (size < 2) {
|
||||||
|
@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector<AnfNodePtr>& list) {
|
AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
|
||||||
size_t size = list.size();
|
size_t size = list.size();
|
||||||
if (size < 2) {
|
if (size < 2) {
|
||||||
MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
|
MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
|
||||||
|
@ -121,7 +121,7 @@ void HyperMap::Init() {
|
||||||
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
|
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
|
||||||
}
|
}
|
||||||
|
|
||||||
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf)
|
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
|
||||||
: MetaFuncGraph("hyper_map"),
|
: MetaFuncGraph("hyper_map"),
|
||||||
fn_leaf_(fn_leaf),
|
fn_leaf_(fn_leaf),
|
||||||
broadcast_(false),
|
broadcast_(false),
|
||||||
|
@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf)
|
||||||
Init();
|
Init();
|
||||||
}
|
}
|
||||||
|
|
||||||
HyperMap::HyperMap(const HyperMap& h)
|
HyperMap::HyperMap(const HyperMap &h)
|
||||||
: MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
|
: MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
|
||||||
Init();
|
Init();
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
|
AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
|
||||||
const ArgsPairList& arg_map) {
|
const ArgsPairList &arg_map) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
std::vector<AnfNodePtr> inputs;
|
std::vector<AnfNodePtr> inputs;
|
||||||
if (fn_arg != nullptr) {
|
if (fn_arg != nullptr) {
|
||||||
|
@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf
|
||||||
}
|
}
|
||||||
|
|
||||||
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
|
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
|
||||||
[](const std::pair<AnfNodePtr, Any>& item) { return item.first; });
|
[](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
|
||||||
return func_graph->NewCNode(inputs);
|
return func_graph->NewCNode(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph,
|
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
|
||||||
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
|
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
|
||||||
std::size_t size = type->elements().size();
|
std::size_t size = type->elements().size();
|
||||||
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) {
|
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||||
auto lhs = std::static_pointer_cast<List>(item.second);
|
auto lhs = std::static_pointer_cast<List>(item.second);
|
||||||
MS_EXCEPTION_IF_NULL(lhs);
|
MS_EXCEPTION_IF_NULL(lhs);
|
||||||
return lhs->elements().size() != size;
|
return lhs->elements().size() != size;
|
||||||
|
@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph
|
||||||
|
|
||||||
(void)std::transform(
|
(void)std::transform(
|
||||||
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
|
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
|
||||||
[&func_graph, i](const std::pair<AnfNodePtr, Any>& item) {
|
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
|
||||||
return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
|
return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph
|
||||||
return func_graph->NewCNode(inputs);
|
return func_graph->NewCNode(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph,
|
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
|
||||||
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
|
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
|
||||||
std::size_t size = type->elements().size();
|
std::size_t size = type->elements().size();
|
||||||
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) {
|
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||||
auto lhs = std::static_pointer_cast<Tuple>(item.second);
|
auto lhs = std::static_pointer_cast<Tuple>(item.second);
|
||||||
MS_EXCEPTION_IF_NULL(lhs);
|
MS_EXCEPTION_IF_NULL(lhs);
|
||||||
return lhs->elements().size() != size;
|
return lhs->elements().size() != size;
|
||||||
|
@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGrap
|
||||||
return func_graph->NewCNode(inputs);
|
return func_graph->NewCNode(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph,
|
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
|
||||||
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
|
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
||||||
|
@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGrap
|
||||||
return func_graph->NewCNode(inputs);
|
return func_graph->NewCNode(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
|
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
|
||||||
bool found = false;
|
bool found = false;
|
||||||
TypeId id = kObjectTypeEnd;
|
TypeId id = kObjectTypeEnd;
|
||||||
std::pair<AnfNodePtr, TypePtr> pair;
|
std::pair<AnfNodePtr, TypePtr> pair;
|
||||||
for (auto& item : arg_map) {
|
for (auto &item : arg_map) {
|
||||||
pair = item;
|
pair = item;
|
||||||
id = item.second->type_id();
|
id = item.second->type_id();
|
||||||
if (nonleaf_.count(id)) {
|
if (nonleaf_.count(id)) {
|
||||||
|
@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
|
||||||
|
|
||||||
if (found) {
|
if (found) {
|
||||||
// In a nonleaf situation, all arguments must have the same generic.
|
// In a nonleaf situation, all arguments must have the same generic.
|
||||||
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr>& item) {
|
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||||
if (item.first != pair.first) {
|
if (item.first != pair.first) {
|
||||||
return item.second->type_id() != pair.second->type_id();
|
return item.second->type_id() != pair.second->type_id();
|
||||||
}
|
}
|
||||||
|
@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
|
||||||
oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
|
oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
|
||||||
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n";
|
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n";
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
for (auto& item : arg_map) {
|
for (auto &item : arg_map) {
|
||||||
oss << ++idx << ": " << item.second->ToString() << "\n";
|
oss << ++idx << ": " << item.second->ToString() << "\n";
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
|
MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
|
||||||
|
@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) {
|
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
|
||||||
TypePtr type_tensor = std::make_shared<TensorType>();
|
TypePtr type_tensor = std::make_shared<TensorType>();
|
||||||
bool flag = std::any_of(
|
bool flag = std::any_of(
|
||||||
args_spec_list.begin(), args_spec_list.end(),
|
args_spec_list.begin(), args_spec_list.end(),
|
||||||
[type_tensor](const std::pair<AnfNodePtr, TypePtr>& item) { return IsSubType(item.second, type_tensor); });
|
[type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
|
||||||
if (flag && broadcast_) {
|
if (flag && broadcast_) {
|
||||||
ArgsPairList ret;
|
ArgsPairList ret;
|
||||||
for (auto& item : args_spec_list) {
|
for (auto &item : args_spec_list) {
|
||||||
if (!IsSubType(item.second, type_tensor)) {
|
if (!IsSubType(item.second, type_tensor)) {
|
||||||
TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
|
TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
|
||||||
ret.push_back(
|
ret.push_back(
|
||||||
|
@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL
|
||||||
return args_spec_list;
|
return args_spec_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) {
|
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
|
||||||
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
||||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
ptrGraph->debug_info()->set_name("hyper_map");
|
ptrGraph->debug_info()->set_name("hyper_map");
|
||||||
|
@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) {
|
||||||
return ptrGraph;
|
return ptrGraph;
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const {
|
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
||||||
if (fn_leaf_ == nullptr) {
|
if (fn_leaf_ == nullptr) {
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||||
// Assert that hypermap's function param does not contain free variables
|
// Assert that hypermap's function param does not contain free variables
|
||||||
|
@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList&
|
||||||
|
|
||||||
AbstractBasePtrList broadened;
|
AbstractBasePtrList broadened;
|
||||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
|
||||||
[](const AbstractBasePtr& arg) -> AbstractBasePtr {
|
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
return arg->Broaden();
|
return arg->Broaden();
|
||||||
});
|
});
|
||||||
return broadened;
|
return broadened;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
|
||||||
(void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
|
(void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
|
||||||
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
|
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
|
||||||
.def(py::init<>());
|
.def(py::init<>());
|
||||||
}));
|
}));
|
||||||
|
|
||||||
FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) {
|
FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
|
||||||
MS_EXCEPTION_IF_NULL(a_tuple);
|
MS_EXCEPTION_IF_NULL(a_tuple);
|
||||||
|
|
||||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||||
|
@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) {
|
FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
|
||||||
MS_EXCEPTION_IF_NULL(a_list);
|
MS_EXCEPTION_IF_NULL(a_list);
|
||||||
|
|
||||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||||
|
@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
if (args_spec_list.size() != 1) {
|
if (args_spec_list.size() != 1) {
|
||||||
MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
|
MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
|
||||||
}
|
}
|
||||||
|
@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list)
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(
|
REGISTER_PYBIND_DEFINE(
|
||||||
Tail_, ([](const py::module* m) {
|
Tail_, ([](const py::module *m) {
|
||||||
(void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string&>());
|
(void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
|
||||||
}));
|
}));
|
||||||
|
|
||||||
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
int tuple_size = SizeToInt(args_spec_list.size());
|
int tuple_size = SizeToInt(args_spec_list.size());
|
||||||
|
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
|
@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg
|
||||||
return fg;
|
return fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param)
|
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
|
||||||
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
|
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
|
||||||
if (get_by_list) {
|
if (get_by_list) {
|
||||||
signatures_ =
|
signatures_ =
|
||||||
|
@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights,
|
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
|
||||||
const std::vector<AnfNodePtr>& params_list, bool applyJ) {
|
const std::vector<AnfNodePtr> ¶ms_list, bool applyJ) {
|
||||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
|
|
||||||
|
@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights,
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
|
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
|
||||||
ValueNodePtr opsTupleItem) {
|
ValueNodePtr opsTupleItem) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
||||||
|
@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the graph.
|
// Generate the graph.
|
||||||
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
if (args_spec_list.size() < 1) {
|
if (args_spec_list.size() < 1) {
|
||||||
MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
|
MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
|
||||||
<< args_spec_list.size() << ".";
|
<< args_spec_list.size() << ".";
|
||||||
|
@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp
|
||||||
return dfBuilder;
|
return dfBuilder;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
|
||||||
(void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
|
(void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
|
||||||
*m, "GradOperation_")
|
*m, "GradOperation_")
|
||||||
.def(py::init<std::string&>(), py::arg("fn"))
|
.def(py::init<std::string &>(), py::arg("fn"))
|
||||||
.def(py::init<std::string&, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
|
.def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
|
||||||
py::arg("get_by_list"), py::arg("sens_param"));
|
py::arg("get_by_list"), py::arg("sens_param"));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) {
|
MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
|
||||||
fn_cache_.clear();
|
fn_cache_.clear();
|
||||||
signatures_ = std::vector<Signature>({// def multitype(*args:ref):
|
signatures_ = std::vector<Signature>({// def multitype(*args:ref):
|
||||||
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
|
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
|
||||||
}
|
}
|
||||||
|
|
||||||
void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) {
|
void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
|
||||||
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
|
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
|
||||||
auto fn = fn_cache_.find(types);
|
auto fn = fn_cache_.find(types);
|
||||||
if (fn != fn_cache_.end()) {
|
if (fn != fn_cache_.end()) {
|
||||||
|
@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn)
|
||||||
fn_cache_[types] = s_fn;
|
fn_cache_[types] = s_fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) {
|
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
|
||||||
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
|
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
|
||||||
auto fn = fn_cache_.find(types);
|
auto fn = fn_cache_.find(types);
|
||||||
if (fn != fn_cache_.end()) {
|
if (fn != fn_cache_.end()) {
|
||||||
|
@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function&
|
||||||
fn_cache_py_[types] = py_fn;
|
fn_cache_py_[types] = py_fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, const py::function& py_fn) {
|
void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
|
||||||
TypePtrList types;
|
TypePtrList types;
|
||||||
for (auto& type_name : types_name) {
|
for (auto &type_name : types_name) {
|
||||||
auto type_ptr = StringToType(type_name);
|
auto type_ptr = StringToType(type_name);
|
||||||
if (type_ptr == nullptr) {
|
if (type_ptr == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error ";
|
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error ";
|
||||||
|
@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, co
|
||||||
Register(types, py_fn);
|
Register(types, py_fn);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) {
|
void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
|
||||||
std::vector<std::string> types_name;
|
std::vector<std::string> types_name;
|
||||||
for (size_t it = 0; it < tuple.size(); ++it) {
|
for (size_t it = 0; it < tuple.size(); ++it) {
|
||||||
py::object name_py = tuple[it];
|
py::object name_py = tuple[it];
|
||||||
|
@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function&
|
||||||
}
|
}
|
||||||
Register(types_name, py_fn);
|
Register(types_name, py_fn);
|
||||||
}
|
}
|
||||||
static TypePtr UnwrapRef(const TypePtr& type) {
|
static TypePtr UnwrapRef(const TypePtr &type) {
|
||||||
if (type->isa<RefType>()) {
|
if (type->isa<RefType>()) {
|
||||||
return type->cast<RefTypePtr>()->subtype();
|
return type->cast<RefTypePtr>()->subtype();
|
||||||
}
|
}
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
|
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
||||||
bool find_fn = false;
|
bool find_fn = false;
|
||||||
py::function py_fn;
|
py::function py_fn;
|
||||||
for (auto& item : fn_cache_py_) {
|
for (auto &item : fn_cache_py_) {
|
||||||
TypePtrList sign = item.first;
|
TypePtrList sign = item.first;
|
||||||
if (sign.size() != types.size()) {
|
if (sign.size() != types.size()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
|
||||||
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
|
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
|
||||||
<< "`, corresponding location info:\n";
|
<< "`, corresponding location info:\n";
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
for (auto& item : fn_cache_py_) {
|
for (auto &item : fn_cache_py_) {
|
||||||
FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
|
FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
|
MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
|
||||||
|
@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
|
||||||
<< oss.str();
|
<< oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
|
||||||
(void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
|
(void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
|
||||||
*m, "MultitypeFuncGraph_")
|
*m, "MultitypeFuncGraph_")
|
||||||
.def(py::init<std::string&>())
|
.def(py::init<std::string &>())
|
||||||
.def("register_fn", &MultitypeFuncGraph::PyRegister);
|
.def("register_fn", &MultitypeFuncGraph::PyRegister);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Generate the ListMap func graph.
|
// Generate the ListMap func graph.
|
||||||
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
size_t args_num = args_spec_list.size();
|
size_t args_num = args_spec_list.size();
|
||||||
// args: fn, list1, list2, ...
|
// args: fn, list1, list2, ...
|
||||||
if (args_num < 2) {
|
if (args_num < 2) {
|
||||||
|
@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis
|
||||||
return fg_ptr;
|
return fg_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgnext_ptr,
|
void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
|
||||||
const FuncGraphPtr& fg_ptr) {
|
const FuncGraphPtr &fg_ptr) {
|
||||||
MS_EXCEPTION_IF_NULL(fg_ptr);
|
MS_EXCEPTION_IF_NULL(fg_ptr);
|
||||||
|
|
||||||
AnfNodePtr fn = fg_ptr->add_parameter();
|
AnfNodePtr fn = fg_ptr->add_parameter();
|
||||||
|
@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr&
|
||||||
fgtrue_ptr->set_output(output_cnode);
|
fgtrue_ptr->set_output(output_cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgcond_ptr,
|
void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
|
||||||
const FuncGraphPtr& fg_ptr) {
|
const FuncGraphPtr &fg_ptr) {
|
||||||
MS_EXCEPTION_IF_NULL(fg_ptr);
|
MS_EXCEPTION_IF_NULL(fg_ptr);
|
||||||
AnfNodePtr fn = fg_ptr->add_parameter();
|
AnfNodePtr fn = fg_ptr->add_parameter();
|
||||||
|
|
||||||
|
@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr&
|
||||||
fg_ptr->set_output(output_cnode);
|
fg_ptr->set_output(output_cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// args: tuple1, tuple2
|
// args: tuple1, tuple2
|
||||||
abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
|
abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
|
||||||
AbstractBasePtr abs_a = args_spec_list[0];
|
AbstractBasePtr abs_a = args_spec_list[0];
|
||||||
|
@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) {
|
int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
|
||||||
MS_EXCEPTION_IF_NULL(scalar);
|
MS_EXCEPTION_IF_NULL(scalar);
|
||||||
return GetValue<int>(scalar->BuildValue());
|
return GetValue<int>(scalar->BuildValue());
|
||||||
}
|
}
|
||||||
|
@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) {
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) {
|
int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
|
||||||
MS_EXCEPTION_IF_NULL(member);
|
MS_EXCEPTION_IF_NULL(member);
|
||||||
|
|
||||||
if (member->isa<AbstractScalar>()) {
|
if (member->isa<AbstractScalar>()) {
|
||||||
|
@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std
|
||||||
<< member->ToString();
|
<< member->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index,
|
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
|
||||||
int* stop_index, int* step_value) {
|
int *stop_index, int *step_value) {
|
||||||
MS_EXCEPTION_IF_NULL(tuple);
|
MS_EXCEPTION_IF_NULL(tuple);
|
||||||
MS_EXCEPTION_IF_NULL(slice);
|
MS_EXCEPTION_IF_NULL(slice);
|
||||||
MS_EXCEPTION_IF_NULL(start_index);
|
MS_EXCEPTION_IF_NULL(start_index);
|
||||||
|
@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// slice a tuple
|
// slice a tuple
|
||||||
// args: tuple, start index, end index, step
|
// args: tuple, start index, end index, step
|
||||||
const std::string op_name("TupleSlice");
|
const std::string op_name("TupleSlice");
|
||||||
|
@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) {
|
int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
|
||||||
unsigned int number_dec = 0;
|
unsigned int number_dec = 0;
|
||||||
for (size_t index = 0; index < number_bin.size(); index++) {
|
for (size_t index = 0; index < number_bin.size(); index++) {
|
||||||
number_dec |= number_bin[index] << index;
|
number_dec |= number_bin[index] << index;
|
||||||
|
@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) {
|
||||||
return static_cast<int>(number_dec);
|
return static_cast<int>(number_dec);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vector<int>* end,
|
void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
|
||||||
std::vector<int>* strides, int length) {
|
std::vector<int> *strides, int length) {
|
||||||
MS_EXCEPTION_IF_NULL(slice);
|
MS_EXCEPTION_IF_NULL(slice);
|
||||||
MS_EXCEPTION_IF_NULL(begin);
|
MS_EXCEPTION_IF_NULL(begin);
|
||||||
MS_EXCEPTION_IF_NULL(end);
|
MS_EXCEPTION_IF_NULL(end);
|
||||||
|
@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vec
|
||||||
strides->push_back(step_value);
|
strides->push_back(step_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector<int>& shape,
|
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
|
||||||
std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) {
|
std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
|
||||||
MS_EXCEPTION_IF_NULL(slice_tuple);
|
MS_EXCEPTION_IF_NULL(slice_tuple);
|
||||||
MS_EXCEPTION_IF_NULL(begin);
|
MS_EXCEPTION_IF_NULL(begin);
|
||||||
MS_EXCEPTION_IF_NULL(end);
|
MS_EXCEPTION_IF_NULL(end);
|
||||||
|
@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple,
|
||||||
return ConvertBinaryToDecimal(shrink);
|
return ConvertBinaryToDecimal(shrink);
|
||||||
}
|
}
|
||||||
|
|
||||||
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector<int>& shape,
|
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
|
||||||
std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) {
|
std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
|
||||||
MS_EXCEPTION_IF_NULL(begin);
|
MS_EXCEPTION_IF_NULL(begin);
|
||||||
MS_EXCEPTION_IF_NULL(end);
|
MS_EXCEPTION_IF_NULL(end);
|
||||||
MS_EXCEPTION_IF_NULL(strides);
|
MS_EXCEPTION_IF_NULL(strides);
|
||||||
|
@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector<int>& shape,
|
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
|
||||||
std::vector<int>* begin, std::vector<int>* end,
|
std::vector<int> *begin, std::vector<int> *end,
|
||||||
std::vector<int>* strides) {
|
std::vector<int> *strides) {
|
||||||
MS_EXCEPTION_IF_NULL(begin);
|
MS_EXCEPTION_IF_NULL(begin);
|
||||||
MS_EXCEPTION_IF_NULL(end);
|
MS_EXCEPTION_IF_NULL(end);
|
||||||
MS_EXCEPTION_IF_NULL(strides);
|
MS_EXCEPTION_IF_NULL(strides);
|
||||||
|
@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// slice a tensor
|
// slice a tensor
|
||||||
// args: tensor, slice or slice tuple
|
// args: tensor, slice or slice tuple
|
||||||
const std::string op_name = std::string("TensorSlice");
|
const std::string op_name = std::string("TensorSlice");
|
||||||
|
@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
|
||||||
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream args_info;
|
std::ostringstream args_info;
|
||||||
for (const auto& arg : args_spec_list) {
|
for (const auto &arg : args_spec_list) {
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
args_info << arg->ToString() << "\n";
|
args_info << arg->ToString() << "\n";
|
||||||
}
|
}
|
||||||
|
@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
|
||||||
return ret_graph;
|
return ret_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(
|
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
||||||
TupleAdd_, ([](const py::module* m) {
|
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
|
||||||
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>());
|
.def(py::init<std::string &>());
|
||||||
}));
|
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) {
|
|
||||||
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
|
|
||||||
.def(py::init<std::string&>());
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
|
||||||
|
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
|
||||||
|
.def(py::init<std::string &>());
|
||||||
|
}));
|
||||||
|
|
||||||
|
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
|
||||||
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
|
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
|
||||||
.def(py::init<std::string&>());
|
.def(py::init<std::string &>());
|
||||||
}));
|
}));
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -47,20 +47,20 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
|
||||||
|
|
||||||
class MultitypeFuncGraph : public MetaFuncGraph {
|
class MultitypeFuncGraph : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit MultitypeFuncGraph(const std::string& name);
|
explicit MultitypeFuncGraph(const std::string &name);
|
||||||
~MultitypeFuncGraph() override = default;
|
~MultitypeFuncGraph() override = default;
|
||||||
MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
|
MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
|
||||||
|
|
||||||
using specialize_fn = FuncGraph* (*)(TypePtrList);
|
using specialize_fn = FuncGraph *(*)(TypePtrList);
|
||||||
// Register a method which specialize based on types vectors;
|
// Register a method which specialize based on types vectors;
|
||||||
virtual void Register(const TypePtrList& types, specialize_fn s_fn);
|
virtual void Register(const TypePtrList &types, specialize_fn s_fn);
|
||||||
virtual void Register(const TypePtrList& types, const py::function& py_fn);
|
virtual void Register(const TypePtrList &types, const py::function &py_fn);
|
||||||
virtual void Register(const std::vector<std::string>& types_name, const py::function& py_fn);
|
virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
|
||||||
virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn);
|
virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
|
||||||
|
|
||||||
FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override;
|
FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
|
||||||
size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
|
size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
|
||||||
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual>& GetPyFunctions() const {
|
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const {
|
||||||
return fn_cache_py_;
|
return fn_cache_py_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
|
||||||
|
|
||||||
class HyperMap : public MetaFuncGraph {
|
class HyperMap : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr);
|
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
|
||||||
HyperMap(const HyperMap& h);
|
HyperMap(const HyperMap &h);
|
||||||
void Init();
|
void Init();
|
||||||
HyperMap& operator=(const HyperMap& h) {
|
HyperMap &operator=(const HyperMap &h) {
|
||||||
if (this != &h) {
|
if (this != &h) {
|
||||||
fn_leaf_ = h.fn_leaf_;
|
fn_leaf_ = h.fn_leaf_;
|
||||||
broadcast_ = h.broadcast_;
|
broadcast_ = h.broadcast_;
|
||||||
|
@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph {
|
||||||
~HyperMap() override = default;
|
~HyperMap() override = default;
|
||||||
MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
|
MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
|
||||||
|
|
||||||
abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override;
|
abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
|
||||||
FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
|
||||||
MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
|
MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
|
AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
|
||||||
const ArgsPairList& arg_map);
|
const ArgsPairList &arg_map);
|
||||||
AnfNodePtr FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
|
AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
|
||||||
const ArgsPairList& arg_map);
|
const ArgsPairList &arg_map);
|
||||||
AnfNodePtr FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
|
AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
|
||||||
const ArgsPairList& arg_map);
|
const ArgsPairList &arg_map);
|
||||||
AnfNodePtr FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
|
AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
|
||||||
const ArgsPairList& arg_map);
|
const ArgsPairList &arg_map);
|
||||||
AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map);
|
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
|
||||||
ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list);
|
ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
|
||||||
|
|
||||||
MultitypeFuncGraphPtr fn_leaf_;
|
MultitypeFuncGraphPtr fn_leaf_;
|
||||||
bool broadcast_;
|
bool broadcast_;
|
||||||
|
@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr<HyperMap>;
|
||||||
|
|
||||||
class HyperMapPy : public HyperMap {
|
class HyperMapPy : public HyperMap {
|
||||||
public:
|
public:
|
||||||
explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr) : HyperMap(fn_leaf) {}
|
explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
|
||||||
~HyperMapPy() override = default;
|
~HyperMapPy() override = default;
|
||||||
MS_DECLARE_PARENT(HyperMapPy, HyperMap)
|
MS_DECLARE_PARENT(HyperMapPy, HyperMap)
|
||||||
};
|
};
|
||||||
|
@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap;
|
||||||
|
|
||||||
class Tail : public MetaFuncGraph {
|
class Tail : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit Tail(const std::string& name) : MetaFuncGraph(name) {}
|
explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~Tail() override = default;
|
~Tail() override = default;
|
||||||
MS_DECLARE_PARENT(Tail, MetaFuncGraph)
|
MS_DECLARE_PARENT(Tail, MetaFuncGraph)
|
||||||
|
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple);
|
FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
|
||||||
FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list);
|
FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);
|
||||||
|
|
||||||
friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using TailPtr = std::shared_ptr<Tail>;
|
using TailPtr = std::shared_ptr<Tail>;
|
||||||
|
|
||||||
class MakeTupleGradient : public MetaFuncGraph {
|
class MakeTupleGradient : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {}
|
explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~MakeTupleGradient() override = default;
|
~MakeTupleGradient() override = default;
|
||||||
MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
|
MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
|
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
|
||||||
|
|
||||||
class GradOperation : public MetaFuncGraph {
|
class GradOperation : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false,
|
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
|
||||||
bool sens_param = false);
|
bool sens_param = false);
|
||||||
~GradOperation() override = default;
|
~GradOperation() override = default;
|
||||||
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
|
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
|
||||||
|
|
||||||
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams,
|
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
|
||||||
bool applyJ = false);
|
bool applyJ = false);
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
bool sens_param() const { return sens_param_; }
|
bool sens_param() const { return sens_param_; }
|
||||||
bool get_all_;
|
bool get_all_;
|
||||||
bool get_by_list_;
|
bool get_by_list_;
|
||||||
bool sens_param_;
|
bool sens_param_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
|
void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
|
||||||
ValueNodePtr opsTupleItem);
|
ValueNodePtr opsTupleItem);
|
||||||
};
|
};
|
||||||
using GradOperationPtr = std::shared_ptr<GradOperation>;
|
using GradOperationPtr = std::shared_ptr<GradOperation>;
|
||||||
|
|
||||||
class ListMap {
|
class ListMap {
|
||||||
public:
|
public:
|
||||||
explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); }
|
explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
|
||||||
~ListMap() = default;
|
~ListMap() = default;
|
||||||
void MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr);
|
void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
|
||||||
void MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr);
|
void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list);
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string name_;
|
std::string name_;
|
||||||
|
@ -181,31 +181,31 @@ class ListMap {
|
||||||
|
|
||||||
class TupleAdd : public MetaFuncGraph {
|
class TupleAdd : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {}
|
explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~TupleAdd() override = default;
|
~TupleAdd() override = default;
|
||||||
MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
|
MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using TupleAddPtr = std::shared_ptr<TupleAdd>;
|
using TupleAddPtr = std::shared_ptr<TupleAdd>;
|
||||||
|
|
||||||
class TupleSlice : public MetaFuncGraph {
|
class TupleSlice : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {}
|
explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~TupleSlice() override = default;
|
~TupleSlice() override = default;
|
||||||
MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
|
MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
|
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
|
||||||
|
|
||||||
class TensorSlice : public MetaFuncGraph {
|
class TensorSlice : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {}
|
explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~TensorSlice() override = default;
|
~TensorSlice() override = default;
|
||||||
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace prim {
|
||||||
namespace {
|
namespace {
|
||||||
using PatternListType = std::initializer_list<BaseRef>;
|
using PatternListType = std::initializer_list<BaseRef>;
|
||||||
|
|
||||||
const std::vector<Signature>& GetSignature(const ValuePtr& function) {
|
const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
||||||
static const auto empty = std::vector<Signature>();
|
static const auto empty = std::vector<Signature>();
|
||||||
if (function->isa<Primitive>()) {
|
if (function->isa<Primitive>()) {
|
||||||
return function->cast<PrimitivePtr>()->signatures();
|
return function->cast<PrimitivePtr>()->signatures();
|
||||||
|
@ -44,8 +44,8 @@ const std::vector<Signature>& GetSignature(const ValuePtr& function) {
|
||||||
return empty;
|
return empty;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list,
|
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
||||||
const std::vector<Signature>& signature, bool has_var, std::vector<AnfNodePtr>* op_inputs) {
|
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *op_inputs) {
|
||||||
std::size_t sig_size = signature.size();
|
std::size_t sig_size = signature.size();
|
||||||
auto positional_size = sig_size;
|
auto positional_size = sig_size;
|
||||||
if (has_var) {
|
if (has_var) {
|
||||||
|
@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
||||||
std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType>& dtypes,
|
std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType> &dtypes,
|
||||||
const abstract::AbstractBasePtrList& args_spec_list) {
|
const abstract::AbstractBasePtrList &args_spec_list) {
|
||||||
// record index for signature.dtypes of the same type
|
// record index for signature.dtypes of the same type
|
||||||
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
||||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||||
|
@ -89,7 +89,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& index : indexs) {
|
for (const auto &index : indexs) {
|
||||||
AbstractBasePtr arg_value = args_spec_list[index];
|
AbstractBasePtr arg_value = args_spec_list[index];
|
||||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||||
|
@ -104,7 +104,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
|
||||||
return dst_type;
|
return dst_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const FuncGraphPtr& graph) {
|
AnfNodePtr DoCast(const AnfNodePtr ¶m, const AnfNodePtr &source_param, const FuncGraphPtr &graph) {
|
||||||
// op and module import path
|
// op and module import path
|
||||||
auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional");
|
auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional");
|
||||||
MS_EXCEPTION_IF_NULL(prim_dtype);
|
MS_EXCEPTION_IF_NULL(prim_dtype);
|
||||||
|
@ -116,11 +116,11 @@ AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const
|
||||||
return NewCNode({cast_node, param, dtype_node}, graph);
|
return NewCNode({cast_node, param, dtype_node}, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DoAutoCast(const std::vector<Signature>& signature, const abstract::AbstractBasePtrList& args_spec_list,
|
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
|
||||||
const FuncGraphPtr& graph, std::vector<AnfNodePtr>* op_inputs) {
|
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs) {
|
||||||
std::vector<SignatureEnumDType> dtypes;
|
std::vector<SignatureEnumDType> dtypes;
|
||||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||||
[](const Signature& sig) { return sig.dtype; });
|
[](const Signature &sig) { return sig.dtype; });
|
||||||
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
||||||
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
||||||
return;
|
return;
|
||||||
|
@ -143,10 +143,10 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function,
|
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
|
||||||
const AbstractBasePtrList& args_spec_list, const std::vector<AnfNodePtr>& params_list) {
|
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) {
|
||||||
// args: original inputs
|
// args: original inputs
|
||||||
auto& signature = GetSignature(function);
|
auto &signature = GetSignature(function);
|
||||||
std::size_t sig_size = signature.size();
|
std::size_t sig_size = signature.size();
|
||||||
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
|
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
|
||||||
if (sig_size > 0) {
|
if (sig_size > 0) {
|
||||||
|
@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function,
|
AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
|
||||||
const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) {
|
const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) {
|
||||||
auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
|
auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
|
||||||
return new_cnode;
|
return new_cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
|
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
|
||||||
|
|
||||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||||
|
|
|
@ -37,17 +37,17 @@ namespace mindspore {
|
||||||
namespace prim {
|
namespace prim {
|
||||||
class DoSignatureMetaFuncGraph : public MetaFuncGraph {
|
class DoSignatureMetaFuncGraph : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function)
|
explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function)
|
||||||
: MetaFuncGraph("S-" + name), function_(function) {}
|
: MetaFuncGraph("S-" + name), function_(function) {}
|
||||||
|
|
||||||
~DoSignatureMetaFuncGraph() override = default;
|
~DoSignatureMetaFuncGraph() override = default;
|
||||||
|
|
||||||
MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph)
|
MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph)
|
||||||
|
|
||||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override;
|
||||||
const ValuePtr function() const { return function_; }
|
const ValuePtr function() const { return function_; }
|
||||||
|
|
||||||
friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) {
|
friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) {
|
||||||
return &lhs == &rhs;
|
return &lhs == &rhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph {
|
||||||
};
|
};
|
||||||
using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>;
|
using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>;
|
||||||
|
|
||||||
AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function,
|
AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
|
||||||
const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs);
|
const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs);
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// namespace to support composite operators definition
|
// namespace to support composite operators definition
|
||||||
namespace prim {
|
namespace prim {
|
||||||
FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) {
|
FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||||
abstract::CheckArgsSize("ListAppend", args_list, 2);
|
abstract::CheckArgsSize("ListAppend", args_list, 2);
|
||||||
|
|
||||||
AbstractBasePtr arg0 = args_list[0];
|
AbstractBasePtr arg0 = args_list[0];
|
||||||
|
@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList&
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) {
|
||||||
(void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_")
|
(void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_")
|
||||||
.def(py::init<std::string&>());
|
.def(py::init<std::string &>());
|
||||||
}));
|
}));
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -28,15 +28,15 @@ namespace mindspore {
|
||||||
namespace prim {
|
namespace prim {
|
||||||
class ListAppend : public MetaFuncGraph {
|
class ListAppend : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {}
|
explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~ListAppend() override = default;
|
~ListAppend() override = default;
|
||||||
MS_DECLARE_PARENT(ListAppend, MetaFuncGraph)
|
MS_DECLARE_PARENT(ListAppend, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override;
|
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||||
friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) {
|
friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) {
|
||||||
os << list_append.name_;
|
os << list_append.name_;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using ListAppendPtr = std::shared_ptr<ListAppend>;
|
using ListAppendPtr = std::shared_ptr<ListAppend>;
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
|
|
|
@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg;
|
||||||
using mindspore::abstract::AbstractTuple;
|
using mindspore::abstract::AbstractTuple;
|
||||||
using mindspore::abstract::AbstractTuplePtr;
|
using mindspore::abstract::AbstractTuplePtr;
|
||||||
|
|
||||||
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// slice a tensor
|
// slice a tensor
|
||||||
// args: tensor, slice or slice tuple
|
// args: tensor, slice or slice tuple
|
||||||
const std::string op_name = std::string("UnpackCall");
|
const std::string op_name = std::string("UnpackCall");
|
||||||
|
@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
|
||||||
AnfNodePtr para_dict = ret_graph->add_parameter();
|
AnfNodePtr para_dict = ret_graph->add_parameter();
|
||||||
auto dict_elems = arg_dict->elements();
|
auto dict_elems = arg_dict->elements();
|
||||||
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
|
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
|
||||||
[ret_graph, para_dict](const AbstractAttribute& item) {
|
[ret_graph, para_dict](const AbstractAttribute &item) {
|
||||||
auto dict_get_item = ret_graph->NewCNode(
|
auto dict_get_item = ret_graph->NewCNode(
|
||||||
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
|
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
|
||||||
return ret_graph->NewCNode(
|
return ret_graph->NewCNode(
|
||||||
|
@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
|
||||||
return ret_graph;
|
return ret_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) {
|
||||||
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
|
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
|
||||||
.def(py::init<std::string&>());
|
.def(py::init<std::string &>());
|
||||||
}));
|
}));
|
||||||
|
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
|
|
|
@ -40,11 +40,11 @@ namespace prim {
|
||||||
// and generate positional parameters and key-value pairs for function.
|
// and generate positional parameters and key-value pairs for function.
|
||||||
class UnpackCall : public MetaFuncGraph {
|
class UnpackCall : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
|
explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~UnpackCall() override = default;
|
~UnpackCall() override = default;
|
||||||
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
|
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
|
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ namespace prim {
|
||||||
using mindspore::abstract::AbstractBase;
|
using mindspore::abstract::AbstractBase;
|
||||||
using mindspore::abstract::AbstractTuple;
|
using mindspore::abstract::AbstractTuple;
|
||||||
|
|
||||||
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// zip operation:
|
// zip operation:
|
||||||
// input: tuple arguments
|
// input: tuple arguments
|
||||||
// output: tuple of items of input iterated on every input
|
// output: tuple of items of input iterated on every input
|
||||||
|
@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
|
||||||
MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
|
MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool {
|
auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
return abs->isa<AbstractTuple>();
|
return abs->isa<AbstractTuple>();
|
||||||
});
|
});
|
||||||
|
@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
|
||||||
}
|
}
|
||||||
|
|
||||||
auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
|
auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
|
||||||
[](const AbstractBasePtr& x, const AbstractBasePtr& y) {
|
[](const AbstractBasePtr &x, const AbstractBasePtr &y) {
|
||||||
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
|
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
|
||||||
});
|
});
|
||||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||||
|
@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
|
||||||
return ret_graph;
|
return ret_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) {
|
REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) {
|
||||||
(void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m,
|
(void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m,
|
||||||
"ZipOperation_")
|
"ZipOperation_")
|
||||||
.def(py::init<std::string&>());
|
.def(py::init<std::string &>());
|
||||||
}));
|
}));
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr;
|
||||||
|
|
||||||
class ZipOperation : public MetaFuncGraph {
|
class ZipOperation : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {}
|
explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
~ZipOperation() override = default;
|
~ZipOperation() override = default;
|
||||||
MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph)
|
MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) {
|
friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) {
|
||||||
os << op.name_;
|
os << op.name_;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
};
|
};
|
||||||
using ZipOperationPtr = std::shared_ptr<ZipOperation>;
|
using ZipOperationPtr = std::shared_ptr<ZipOperation>;
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
|
|
|
@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
|
||||||
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
||||||
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
|
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
|
||||||
|
|
||||||
ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) {
|
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
|
||||||
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
|
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
|
||||||
ValuePtr node = nullptr;
|
ValuePtr node = nullptr;
|
||||||
bool succ = parse::ConvertData(obj, &node);
|
bool succ = parse::ConvertData(obj, &node);
|
||||||
|
|
|
@ -26,8 +26,8 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// namespace to support primitive operators
|
// namespace to support primitive operators
|
||||||
namespace prim {
|
namespace prim {
|
||||||
ValuePtr GetPythonOps(const std::string& op_name,
|
ValuePtr GetPythonOps(const std::string &op_name,
|
||||||
const std::string& module_name = "mindspore._extends.parse.standard_method");
|
const std::string &module_name = "mindspore._extends.parse.standard_method");
|
||||||
|
|
||||||
// Arithmetic
|
// Arithmetic
|
||||||
extern const PrimitivePtr kPrimScalarAdd;
|
extern const PrimitivePtr kPrimScalarAdd;
|
||||||
|
@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset;
|
||||||
|
|
||||||
class DoSignaturePrimitive : public Primitive {
|
class DoSignaturePrimitive : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function)
|
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||||
: Primitive("S-Prim-" + name), function_(function) {}
|
: Primitive("S-Prim-" + name), function_(function) {}
|
||||||
|
|
||||||
~DoSignaturePrimitive() override = default;
|
~DoSignaturePrimitive() override = default;
|
||||||
|
@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
|
||||||
|
|
||||||
class UnpackGraphPrimitive : public Primitive {
|
class UnpackGraphPrimitive : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args)
|
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)
|
||||||
: Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
|
: Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
|
||||||
~UnpackGraphPrimitive() override = default;
|
~UnpackGraphPrimitive() override = default;
|
||||||
MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)
|
MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)
|
||||||
|
|
|
@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction()
|
||||||
{"scalar_sub", kPrimTypeTwoArgs},
|
{"scalar_sub", kPrimTypeTwoArgs},
|
||||||
{"scalar_floordiv", kPrimTypeTwoArgs}}) {}
|
{"scalar_floordiv", kPrimTypeTwoArgs}}) {}
|
||||||
|
|
||||||
bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const {
|
bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
|
||||||
bool result = false;
|
bool result = false;
|
||||||
|
|
||||||
if (func != nullptr) {
|
if (func != nullptr) {
|
||||||
|
@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const {
|
int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
int prim_type = static_cast<int>(kPrimTypeUnknown);
|
int prim_type = static_cast<int>(kPrimTypeUnknown);
|
||||||
|
|
||||||
|
|
|
@ -41,21 +41,21 @@ class PrimToFunction;
|
||||||
class PrimToFunction {
|
class PrimToFunction {
|
||||||
public:
|
public:
|
||||||
// Return a thread-safe singleton instance
|
// Return a thread-safe singleton instance
|
||||||
static PrimToFunction& GetInstance() {
|
static PrimToFunction &GetInstance() {
|
||||||
static PrimToFunction instance;
|
static PrimToFunction instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
PrimToFunction(const PrimToFunction&) = delete;
|
PrimToFunction(const PrimToFunction &) = delete;
|
||||||
PrimToFunction& operator=(const PrimToFunction&) = delete;
|
PrimToFunction &operator=(const PrimToFunction &) = delete;
|
||||||
~PrimToFunction() = default;
|
~PrimToFunction() = default;
|
||||||
|
|
||||||
// Get the args and return value for a primitive instance.
|
// Get the args and return value for a primitive instance.
|
||||||
bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const;
|
bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PrimToFunction();
|
PrimToFunction();
|
||||||
// Get the number of primitive arguments
|
// Get the number of primitive arguments
|
||||||
int GetPrimType(const PrimitivePtr& prim) const;
|
int GetPrimType(const PrimitivePtr &prim) const;
|
||||||
const std::unordered_map<std::string, int> prim_func_type_map_;
|
const std::unordered_map<std::string, int> prim_func_type_map_;
|
||||||
};
|
};
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ad {
|
namespace ad {
|
||||||
Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller)
|
Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller)
|
||||||
: primal_(primal), caller_(caller), dout_(nullptr) {
|
: primal_(primal), caller_(caller), dout_(nullptr) {
|
||||||
if (k != nullptr) {
|
if (k != nullptr) {
|
||||||
k_ = k;
|
k_ = k;
|
||||||
|
@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP
|
||||||
|
|
||||||
AnfNodePtr Adjoint::k() { return k_; }
|
AnfNodePtr Adjoint::k() { return k_; }
|
||||||
|
|
||||||
void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); }
|
void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); }
|
||||||
|
|
||||||
void Adjoint::UpdateK(const AnfNodePtr& new_k) {
|
void Adjoint::UpdateK(const AnfNodePtr &new_k) {
|
||||||
MS_EXCEPTION_IF_NULL(new_k);
|
MS_EXCEPTION_IF_NULL(new_k);
|
||||||
MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString();
|
MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString();
|
||||||
// In recursive case, it needs update.
|
// In recursive case, it needs update.
|
||||||
for (auto& user : k_user_) {
|
for (auto &user : k_user_) {
|
||||||
MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k"
|
MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k"
|
||||||
<< new_k->ToString();
|
<< new_k->ToString();
|
||||||
if (user.first->input(user.second) != k_) {
|
if (user.first->input(user.second) != k_) {
|
||||||
|
@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; }
|
||||||
|
|
||||||
AnfNodePtr Adjoint::dout() { return dout_hole_; }
|
AnfNodePtr Adjoint::dout() { return dout_hole_; }
|
||||||
|
|
||||||
void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) {
|
void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) {
|
||||||
dout_user_.emplace_back(std::make_pair(user, index));
|
dout_user_.emplace_back(std::make_pair(user, index));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) {
|
void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) {
|
||||||
if (dout_ != nullptr) {
|
if (dout_ != nullptr) {
|
||||||
MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString();
|
MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString();
|
||||||
auto add = prim::GetPythonOps("hyper_add");
|
auto add = prim::GetPythonOps("hyper_add");
|
||||||
|
@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) {
|
||||||
|
|
||||||
void Adjoint::CallDoutHole() {
|
void Adjoint::CallDoutHole() {
|
||||||
if (dout_ != nullptr) {
|
if (dout_ != nullptr) {
|
||||||
for (auto& user : dout_user_) {
|
for (auto &user : dout_user_) {
|
||||||
MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout "
|
MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout "
|
||||||
<< dout_->ToString();
|
<< dout_->ToString();
|
||||||
if (user.first->input(user.second) != dout_hole_) {
|
if (user.first->input(user.second) != dout_hole_) {
|
||||||
|
|
|
@ -28,15 +28,15 @@ namespace mindspore {
|
||||||
namespace ad {
|
namespace ad {
|
||||||
class Adjoint {
|
class Adjoint {
|
||||||
public:
|
public:
|
||||||
Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller);
|
Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller);
|
||||||
~Adjoint() = default;
|
~Adjoint() = default;
|
||||||
AnfNodePtr primal();
|
AnfNodePtr primal();
|
||||||
AnfNodePtr k();
|
AnfNodePtr k();
|
||||||
void UpdateK(const AnfNodePtr& k);
|
void UpdateK(const AnfNodePtr &k);
|
||||||
void RegisterKUser(const CNodePtr& user, size_t index);
|
void RegisterKUser(const CNodePtr &user, size_t index);
|
||||||
AnfNodePtr dout();
|
AnfNodePtr dout();
|
||||||
void AccumulateDout(const AnfNodePtr& dout_factor);
|
void AccumulateDout(const AnfNodePtr &dout_factor);
|
||||||
void RegisterDoutUser(const CNodePtr& user, size_t index);
|
void RegisterDoutUser(const CNodePtr &user, size_t index);
|
||||||
void CallDoutHole();
|
void CallDoutHole();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList;
|
||||||
using mindspore::abstract::AbstractScalar;
|
using mindspore::abstract::AbstractScalar;
|
||||||
using mindspore::abstract::AbstractTuple;
|
using mindspore::abstract::AbstractTuple;
|
||||||
|
|
||||||
static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
|
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
|
||||||
if (t == nullptr) {
|
if (t == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
|
||||||
AbstractBasePtrList baselist;
|
AbstractBasePtrList baselist;
|
||||||
auto attributes = abs_class->attributes();
|
auto attributes = abs_class->attributes();
|
||||||
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
|
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
|
||||||
[](const AbstractAttribute& item) { return item.second; });
|
[](const AbstractAttribute &item) { return item.second; });
|
||||||
res = std::make_shared<AbstractTuple>(baselist);
|
res = std::make_shared<AbstractTuple>(baselist);
|
||||||
} else if (t->isa<AbstractDictionary>()) {
|
} else if (t->isa<AbstractDictionary>()) {
|
||||||
auto abs_dict = dyn_cast<AbstractDictionary>(t);
|
auto abs_dict = dyn_cast<AbstractDictionary>(t);
|
||||||
AbstractBasePtrList baselist;
|
AbstractBasePtrList baselist;
|
||||||
auto elements = abs_dict->elements();
|
auto elements = abs_dict->elements();
|
||||||
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
|
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
|
||||||
[](const AbstractAttribute& item) { return item.second; });
|
[](const AbstractAttribute &item) { return item.second; });
|
||||||
res = std::make_shared<AbstractTuple>(baselist);
|
res = std::make_shared<AbstractTuple>(baselist);
|
||||||
} else if (t->isa<AbstractList>()) {
|
} else if (t->isa<AbstractList>()) {
|
||||||
auto abs_dict = dyn_cast<AbstractList>(t);
|
auto abs_dict = dyn_cast<AbstractList>(t);
|
||||||
|
@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
|
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
// Inputs should be [getattr, data, attribute]
|
// Inputs should be [getattr, data, attribute]
|
||||||
MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
|
MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
|
||||||
|
|
||||||
|
@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
|
||||||
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
|
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
|
||||||
|
|
||||||
auto ct = dyn_cast<AbstractClass>(dt);
|
auto ct = dyn_cast<AbstractClass>(dt);
|
||||||
const auto& cmap = ct->attributes();
|
const auto &cmap = ct->attributes();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
for (auto& item : cmap) {
|
for (auto &item : cmap) {
|
||||||
if (cons_is_str && item.first == cons_str) {
|
if (cons_is_str && item.first == cons_str) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
|
||||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
|
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
// Inputs should be [dict_getitem, dict, item]
|
// Inputs should be [dict_getitem, dict, item]
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
|
MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
|
||||||
|
|
||||||
AnfNodePtr data = inputs[1];
|
AnfNodePtr data = inputs[1];
|
||||||
|
@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
|
||||||
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
|
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
|
||||||
|
|
||||||
auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
|
auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
|
||||||
const auto& cmap = ct->elements();
|
const auto &cmap = ct->elements();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
for (auto& item : cmap) {
|
for (auto &item : cmap) {
|
||||||
if (cons_is_str && item.first == cons_str) {
|
if (cons_is_str && item.first == cons_str) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
|
||||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) {
|
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
|
@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) {
|
||||||
return node->func_graph()->NewCNode(inputs);
|
return node->func_graph()->NewCNode(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ErasePartialNode(const CNodePtr& node) {
|
AnfNodePtr ErasePartialNode(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
// Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
|
// Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
|
||||||
MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
|
MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
|
||||||
|
|
||||||
|
@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) {
|
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
|
@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) {
|
||||||
return node->func_graph()->NewCNode(inputs);
|
return node->func_graph()->NewCNode(inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) {
|
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
// Inputs should be [list_getitem, list, item]
|
// Inputs should be [list_getitem, list, item]
|
||||||
if (inputs.size() < 3) {
|
if (inputs.size() < 3) {
|
||||||
MS_LOG(EXCEPTION) << "Node's input number < 3.";
|
MS_LOG(EXCEPTION) << "Node's input number < 3.";
|
||||||
|
@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) {
|
||||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) {
|
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
// Inputs should be [list_setitem, list, index, item]
|
// Inputs should be [list_setitem, list, index, item]
|
||||||
if (inputs.size() < 4) {
|
if (inputs.size() < 4) {
|
||||||
MS_LOG(EXCEPTION) << "Node's input number < 4.";
|
MS_LOG(EXCEPTION) << "Node's input number < 4.";
|
||||||
|
@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) {
|
||||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr EraseMakeDictNode(const CNodePtr& node) {
|
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
|
MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
|
||||||
return inputs[2];
|
return inputs[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) {
|
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
// Inputs should be [make_keyword_arg, key, value]
|
// Inputs should be [make_keyword_arg, key, value]
|
||||||
MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
|
MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
|
||||||
return inputs[2];
|
return inputs[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) {
|
AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const auto& inputs = node->inputs();
|
const auto &inputs = node->inputs();
|
||||||
// Inputs should be [extract_keyword_arg, arg, key]
|
// Inputs should be [extract_keyword_arg, arg, key]
|
||||||
MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
|
MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
|
||||||
return inputs[2];
|
return inputs[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) {
|
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) {
|
||||||
const int DEPTH_MAX = 5;
|
const int DEPTH_MAX = 5;
|
||||||
if (depth > DEPTH_MAX) {
|
if (depth > DEPTH_MAX) {
|
||||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
|
MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
|
||||||
}
|
}
|
||||||
std::vector<ValuePtr> elements;
|
std::vector<ValuePtr> elements;
|
||||||
for (const auto& it : value_list->value()) {
|
for (const auto &it : value_list->value()) {
|
||||||
ValuePtr value = nullptr;
|
ValuePtr value = nullptr;
|
||||||
if (it->isa<ValueList>()) {
|
if (it->isa<ValueList>()) {
|
||||||
value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
|
value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
|
||||||
|
@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d
|
||||||
return std::make_shared<ValueTuple>(elements);
|
return std::make_shared<ValueTuple>(elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) {
|
AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
ValuePtr value = node->value();
|
ValuePtr value = node->value();
|
||||||
auto value_list = value->cast<ValueListPtr>();
|
auto value_list = value->cast<ValueListPtr>();
|
||||||
|
@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) {
|
||||||
// Convert class to Tuple
|
// Convert class to Tuple
|
||||||
// Convert getattr to getitem
|
// Convert getattr to getitem
|
||||||
// Convert make_record to make_tuple
|
// Convert make_record to make_tuple
|
||||||
void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
|
void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
manager->AddFuncGraph(root);
|
manager->AddFuncGraph(root);
|
||||||
|
|
||||||
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||||
AnfNodeSet all_node = manager->all_nodes();
|
AnfNodeSet all_node = manager->all_nodes();
|
||||||
for (auto& node : all_node) {
|
for (auto &node : all_node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
AnfNodePtr new_node = nullptr;
|
AnfNodePtr new_node = nullptr;
|
||||||
|
@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr&
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& node : manager->all_nodes()) {
|
for (auto &node : manager->all_nodes()) {
|
||||||
auto ret = Reabs(node->abstract());
|
auto ret = Reabs(node->abstract());
|
||||||
node->set_abstract(ret);
|
node->set_abstract(ret);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// expand tuples in graph parameters
|
// expand tuples in graph parameters
|
||||||
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph,
|
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
|
||||||
const std::vector<AnfNodePtr>& params) {
|
const std::vector<AnfNodePtr> ¶ms) {
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
||||||
std::vector<AnfNodePtr> new_params;
|
std::vector<AnfNodePtr> new_params;
|
||||||
for (const auto& param : params) {
|
for (const auto ¶m : params) {
|
||||||
MS_EXCEPTION_IF_NULL(param);
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
auto param_abs = param->abstract();
|
auto param_abs = param->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(param_abs);
|
MS_EXCEPTION_IF_NULL(param_abs);
|
||||||
|
@ -350,7 +350,7 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con
|
||||||
std::vector<AnfNodePtr> new_param;
|
std::vector<AnfNodePtr> new_param;
|
||||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||||
auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
|
auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
|
||||||
for (auto& elem : abs_tuple->elements()) {
|
for (auto &elem : abs_tuple->elements()) {
|
||||||
auto np = std::make_shared<Parameter>(func_graph);
|
auto np = std::make_shared<Parameter>(func_graph);
|
||||||
np->set_abstract(elem);
|
np->set_abstract(elem);
|
||||||
new_param.emplace_back(np);
|
new_param.emplace_back(np);
|
||||||
|
@ -366,11 +366,11 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con
|
||||||
}
|
}
|
||||||
|
|
||||||
// expand tuples in graph applies
|
// expand tuples in graph applies
|
||||||
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const std::vector<AnfNodePtr>& inputs) {
|
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
|
||||||
std::vector<AnfNodePtr> new_inputs;
|
std::vector<AnfNodePtr> new_inputs;
|
||||||
for (const auto& input : inputs) {
|
for (const auto &input : inputs) {
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
|
||||||
auto input_abs = input->abstract();
|
auto input_abs = input->abstract();
|
||||||
|
@ -391,7 +391,7 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
std::vector<AnfNodePtr> new_input;
|
std::vector<AnfNodePtr> new_input;
|
||||||
auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
|
auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
|
||||||
for (auto& elem : abs_tuple->elements()) {
|
for (auto &elem : abs_tuple->elements()) {
|
||||||
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
|
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
|
||||||
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
|
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
|
||||||
c_node->input(2)->set_abstract(aptr);
|
c_node->input(2)->set_abstract(aptr);
|
||||||
|
@ -416,19 +416,19 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st
|
||||||
// tuples in Graph's parameters: AbstractTuple (a, b, c) -->
|
// tuples in Graph's parameters: AbstractTuple (a, b, c) -->
|
||||||
// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
|
// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
|
||||||
// cppcheck-suppress unusedFunction
|
// cppcheck-suppress unusedFunction
|
||||||
void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
|
void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
manager->AddFuncGraph(root);
|
manager->AddFuncGraph(root);
|
||||||
|
|
||||||
// NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
// NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||||
AnfNodeSet all_node = manager->all_nodes();
|
AnfNodeSet all_node = manager->all_nodes();
|
||||||
for (auto& node : all_node) {
|
for (auto &node : all_node) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
if (cnode == nullptr) {
|
if (cnode == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& inputs = cnode->inputs();
|
const auto &inputs = cnode->inputs();
|
||||||
|
|
||||||
// Bypass the first input in inputs as it's fn.
|
// Bypass the first input in inputs as it's fn.
|
||||||
if (!IsValueNode<Primitive>(inputs[0])) {
|
if (!IsValueNode<Primitive>(inputs[0])) {
|
||||||
|
@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphSet all_graph = manager->func_graphs();
|
FuncGraphSet all_graph = manager->func_graphs();
|
||||||
for (auto& func_graph : all_graph) {
|
for (auto &func_graph : all_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
|
auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
|
||||||
manager->SetParameters(func_graph, expand_p);
|
manager->SetParameters(func_graph, expand_p);
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
// Automatically adding control depend based on effect order and side effect analysis.
|
// Automatically adding control depend based on effect order and side effect analysis.
|
||||||
void AddControlDepend(const FuncGraphPtr& graph);
|
void AddControlDepend(const FuncGraphPtr &graph);
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_
|
||||||
|
|
|
@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func
|
||||||
nodes.push_back(func_node);
|
nodes.push_back(func_node);
|
||||||
// {unpackcall, {GradOperation, ...}, args...}
|
// {unpackcall, {GradOperation, ...}, args...}
|
||||||
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
|
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
|
||||||
[](const AnfNodePtr& node) { return node; });
|
[](const AnfNodePtr &node) { return node; });
|
||||||
unpack_graph_node = func_graph->NewCNode(nodes);
|
unpack_graph_node = func_graph->NewCNode(nodes);
|
||||||
} else {
|
} else {
|
||||||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
|
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
|
||||||
|
@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func
|
||||||
nodes.push_back(func_node);
|
nodes.push_back(func_node);
|
||||||
// {{GradOperation, ...}, args...}
|
// {{GradOperation, ...}, args...}
|
||||||
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
|
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
|
||||||
[](const AnfNodePtr& node) { return node; });
|
[](const AnfNodePtr &node) { return node; });
|
||||||
unpack_graph_node = func_graph->NewCNode(nodes);
|
unpack_graph_node = func_graph->NewCNode(nodes);
|
||||||
}
|
}
|
||||||
return unpack_graph_node;
|
return unpack_graph_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
// get metagraph of value node
|
// get metagraph of value node
|
||||||
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
|
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
|
||||||
ValuePtr value;
|
ValuePtr value;
|
||||||
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
|
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
|
||||||
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
|
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
|
||||||
|
@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if node is a specific metafuncgraph op
|
// check if node is a specific metafuncgraph op
|
||||||
bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) {
|
bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) {
|
||||||
if (node != nullptr) {
|
if (node != nullptr) {
|
||||||
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
||||||
if (meta_func_graph_ptr == nullptr) {
|
if (meta_func_graph_ptr == nullptr) {
|
||||||
|
@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr
|
||||||
|
|
||||||
// {{GradOperation, g, w}, Ys}
|
// {{GradOperation, g, w}, Ys}
|
||||||
// {UnPackCall, {GradOperation, g, w}, Ys}
|
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||||
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) {
|
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,20 +31,20 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
/* namespace to support opt */
|
/* namespace to support opt */
|
||||||
namespace opt {
|
namespace opt {
|
||||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim,
|
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
|
||||||
const RenormAction& renorm_action) {
|
const RenormAction &renorm_action) {
|
||||||
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); };
|
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
|
||||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||||
}
|
}
|
||||||
|
|
||||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
|
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||||
const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) {
|
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
|
||||||
auto fn = [prims](const AnfNodePtr& node) -> bool {
|
auto fn = [prims](const AnfNodePtr &node) -> bool {
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& prim : prims) {
|
for (auto &prim : prims) {
|
||||||
if (IsPrimitiveCNode(node, prim)) {
|
if (IsPrimitiveCNode(node, prim)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::
|
||||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||||
}
|
}
|
||||||
|
|
||||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
|
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||||
const PredicateFuncType& predicate, const RenormAction& renorm_action) {
|
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
|
||||||
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
|
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const {
|
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
|
||||||
#ifdef ENABLE_PROFILE
|
#ifdef ENABLE_PROFILE
|
||||||
double t = GetTime();
|
double t = GetTime();
|
||||||
#endif
|
#endif
|
||||||
|
@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node,
|
bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
|
||||||
const SubstitutionPtr& transform) const {
|
const SubstitutionPtr &transform) const {
|
||||||
FuncGraphManagerPtr manager = optimizer->manager();
|
FuncGraphManagerPtr manager = optimizer->manager();
|
||||||
std::unordered_set<AnfNodePtr> seen_node;
|
std::unordered_set<AnfNodePtr> seen_node;
|
||||||
std::deque<AnfNodePtr> todo{root_node};
|
std::deque<AnfNodePtr> todo{root_node};
|
||||||
|
@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
auto& inputs = node->cast<CNodePtr>()->inputs();
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
|
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& node_users = manager->node_users();
|
auto &node_users = manager->node_users();
|
||||||
if (change && node_users.find(node) != node_users.end()) {
|
if (change && node_users.find(node) != node_users.end()) {
|
||||||
for (auto& use : node_users[node]) {
|
for (auto &use : node_users[node]) {
|
||||||
auto use_node = use.first;
|
auto use_node = use.first;
|
||||||
todo.push_back(use_node);
|
todo.push_back(use_node);
|
||||||
if (seen_node.find(use_node) != seen_node.end()) {
|
if (seen_node.find(use_node) != seen_node.end()) {
|
||||||
|
@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo
|
||||||
return changes;
|
return changes;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const {
|
bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
|
||||||
MS_EXCEPTION_IF_NULL(optimizer);
|
MS_EXCEPTION_IF_NULL(optimizer);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
FuncGraphManagerPtr manager = optimizer->manager();
|
FuncGraphManagerPtr manager = optimizer->manager();
|
||||||
|
@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize
|
||||||
|
|
||||||
do {
|
do {
|
||||||
loop = false;
|
loop = false;
|
||||||
for (auto const& transform : list_) {
|
for (auto const &transform : list_) {
|
||||||
auto change = ApplyTransform(optimizer, func_graph->output(), transform);
|
auto change = ApplyTransform(optimizer, func_graph->output(), transform);
|
||||||
changes = changes || change;
|
changes = changes || change;
|
||||||
loop = loop || change;
|
loop = loop || change;
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) {
|
std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) {
|
||||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||||
MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
|
MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
|
||||||
<< MAX_RECURSIVE_CALL_TIMES;
|
<< MAX_RECURSIVE_CALL_TIMES;
|
||||||
|
@ -39,7 +39,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
auto node_set = manager->node_users()[para];
|
auto node_set = manager->node_users()[para];
|
||||||
std::unordered_set<CNodePtr> cnode_set;
|
std::unordered_set<CNodePtr> cnode_set;
|
||||||
for (auto& node_pair : node_set) {
|
for (auto &node_pair : node_set) {
|
||||||
auto cnode = node_pair.first->cast<CNodePtr>();
|
auto cnode = node_pair.first->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||||
|
@ -54,7 +54,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
|
||||||
(void)cnode_set.emplace(cnode);
|
(void)cnode_set.emplace(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
|
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
|
||||||
for (auto& cnode_sub : cnode_set_sub) {
|
for (auto &cnode_sub : cnode_set_sub) {
|
||||||
(void)cnode_set.emplace(cnode_sub);
|
(void)cnode_set.emplace(cnode_sub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,8 +63,8 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceFusion::AddNodeToGraph() {
|
Status AllreduceFusion::AddNodeToGraph() {
|
||||||
const auto& parameters = root_graph_->parameters();
|
const auto ¶meters = root_graph_->parameters();
|
||||||
for (auto& parameter : parameters) {
|
for (auto ¶meter : parameters) {
|
||||||
if (!ParameterRequireGrad(parameter)) {
|
if (!ParameterRequireGrad(parameter)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() {
|
||||||
if (cnode_set.empty()) {
|
if (cnode_set.empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (auto& cnode : cnode_set) {
|
for (auto &cnode : cnode_set) {
|
||||||
MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
|
MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
|
||||||
if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
|
if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
|
MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
|
||||||
|
@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const {
|
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const {
|
||||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||||
MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
|
MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
|
||||||
<< MAX_RECURSIVE_CALL_TIMES;
|
<< MAX_RECURSIVE_CALL_TIMES;
|
||||||
|
@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi
|
||||||
return cnode_dist;
|
return cnode_dist;
|
||||||
} else {
|
} else {
|
||||||
auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
|
auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
|
||||||
for (auto& ele : cnode_dist_next) {
|
for (auto &ele : cnode_dist_next) {
|
||||||
cnode_dist[ele.first] = cost + ele.second;
|
cnode_dist[ele.first] = cost + ele.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto cnode_dist_next = FindNextCNodes(cnode);
|
auto cnode_dist_next = FindNextCNodes(cnode);
|
||||||
for (auto& ele : cnode_dist_next) {
|
for (auto &ele : cnode_dist_next) {
|
||||||
cnode_dist[ele.first] = ele.second;
|
cnode_dist[ele.first] = ele.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cnode_dist;
|
return cnode_dist;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const {
|
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const {
|
||||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||||
MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
|
MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
|
||||||
<< MAX_RECURSIVE_CALL_TIMES;
|
<< MAX_RECURSIVE_CALL_TIMES;
|
||||||
}
|
}
|
||||||
const auto& from_inputs = from->inputs();
|
const auto &from_inputs = from->inputs();
|
||||||
std::unordered_map<CNodePtr, double> dist_map;
|
std::unordered_map<CNodePtr, double> dist_map;
|
||||||
MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
|
MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
|
||||||
for (auto& input_node : from_inputs) {
|
for (auto &input_node : from_inputs) {
|
||||||
auto cnode_dist = FindCNode(input_node, recursive_times + 1);
|
auto cnode_dist = FindCNode(input_node, recursive_times + 1);
|
||||||
for (auto& ele : cnode_dist) {
|
for (auto &ele : cnode_dist) {
|
||||||
(void)dist_map.emplace(ele);
|
(void)dist_map.emplace(ele);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu
|
||||||
|
|
||||||
Status AllreduceFusion::AddEdgeToGraph() {
|
Status AllreduceFusion::AddEdgeToGraph() {
|
||||||
std::unordered_map<CNodePtr, int32_t> cnode_state_map;
|
std::unordered_map<CNodePtr, int32_t> cnode_state_map;
|
||||||
const auto& cnodes = allreduce_graph_.cnode_set();
|
const auto &cnodes = allreduce_graph_.cnode_set();
|
||||||
for (auto& cnode : cnodes) {
|
for (auto &cnode : cnodes) {
|
||||||
cnode_state_map[cnode] = 0;
|
cnode_state_map[cnode] = 0;
|
||||||
}
|
}
|
||||||
const auto& head_cnode = allreduce_graph_.head_cnode();
|
const auto &head_cnode = allreduce_graph_.head_cnode();
|
||||||
std::queue<CNodePtr> cnode_queue;
|
std::queue<CNodePtr> cnode_queue;
|
||||||
cnode_queue.emplace(head_cnode);
|
cnode_queue.emplace(head_cnode);
|
||||||
cnode_state_map[head_cnode] = 1;
|
cnode_state_map[head_cnode] = 1;
|
||||||
|
@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() {
|
||||||
cnode_queue.pop();
|
cnode_queue.pop();
|
||||||
cnode_state_map[cur_cnode] = 2;
|
cnode_state_map[cur_cnode] = 2;
|
||||||
auto next = FindNextCNodes(cur_cnode);
|
auto next = FindNextCNodes(cur_cnode);
|
||||||
for (auto& ele : next) {
|
for (auto &ele : next) {
|
||||||
auto& cnode = ele.first;
|
auto &cnode = ele.first;
|
||||||
auto& dist = ele.second;
|
auto &dist = ele.second;
|
||||||
if (cnode_state_map[cnode] == 0) {
|
if (cnode_state_map[cnode] == 0) {
|
||||||
cnode_queue.emplace(cnode);
|
cnode_queue.emplace(cnode);
|
||||||
cnode_state_map[cnode] = 1;
|
cnode_state_map[cnode] = 1;
|
||||||
|
@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) {
|
std::vector<CNodePtr> FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) {
|
||||||
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
|
||||||
MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
|
MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
|
||||||
<< MAX_RECURSIVE_CALL_TIMES;
|
<< MAX_RECURSIVE_CALL_TIMES;
|
||||||
|
@ -184,7 +184,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
AnfNodeIndexSet node_set = manager->node_users()[para];
|
AnfNodeIndexSet node_set = manager->node_users()[para];
|
||||||
std::vector<CNodePtr> cnode_list;
|
std::vector<CNodePtr> cnode_list;
|
||||||
for (auto& node_pair : node_set) {
|
for (auto &node_pair : node_set) {
|
||||||
auto cnode = node_pair.first->cast<CNodePtr>();
|
auto cnode = node_pair.first->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||||
|
@ -210,7 +210,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time
|
||||||
return cnode_list;
|
return cnode_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) {
|
void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) {
|
||||||
MS_EXCEPTION_IF_NULL(mirror_cnode);
|
MS_EXCEPTION_IF_NULL(mirror_cnode);
|
||||||
MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
|
MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
|
||||||
auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
|
auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
|
||||||
|
@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st
|
||||||
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
|
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) {
|
Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) {
|
||||||
auto mirror_cnodes = FindMirror(para);
|
auto mirror_cnodes = FindMirror(para);
|
||||||
if (mirror_cnodes.empty()) {
|
if (mirror_cnodes.empty()) {
|
||||||
MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
|
MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
if (mirror_cnodes.size() > 2) {
|
if (mirror_cnodes.size() > 2) {
|
||||||
for (auto& mirror_cnode : mirror_cnodes) {
|
for (auto &mirror_cnode : mirror_cnodes) {
|
||||||
MS_EXCEPTION_IF_NULL(mirror_cnode);
|
MS_EXCEPTION_IF_NULL(mirror_cnode);
|
||||||
MS_LOG(INFO) << mirror_cnode->DebugString();
|
MS_LOG(INFO) << mirror_cnode->DebugString();
|
||||||
}
|
}
|
||||||
|
@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) {
|
||||||
<< "Mirror CNode found.";
|
<< "Mirror CNode found.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
for (auto& mirror_cnode : mirror_cnodes) {
|
for (auto &mirror_cnode : mirror_cnodes) {
|
||||||
auto parameter_name = ParameterName(para);
|
auto parameter_name = ParameterName(para);
|
||||||
SetMirrorFusion(mirror_cnode, fusion, parameter_name);
|
SetMirrorFusion(mirror_cnode, fusion, parameter_name);
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusion) {
|
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> ¶s, int32_t fusion) {
|
||||||
for (auto& param_node : paras) {
|
for (auto ¶m_node : paras) {
|
||||||
if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
|
if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
|
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusi
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceFusion::SetFusion(const std::vector<double>& cost_map) {
|
Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
|
||||||
if (cost_map.size() < 2) {
|
if (cost_map.size() < 2) {
|
||||||
MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
|
MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) {
|
||||||
return SetFusionByBackwardCompAndAllreduceTime();
|
return SetFusionByBackwardCompAndAllreduceTime();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) {
|
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
MS_LOG(ERROR) << "ret is nullptr.";
|
MS_LOG(ERROR) << "ret is nullptr.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
|
|
@ -50,15 +50,15 @@ class AllreduceFusion {
|
||||||
allreduce_bandwidth_(0),
|
allreduce_bandwidth_(0),
|
||||||
computation_time_parameter_(0) {}
|
computation_time_parameter_(0) {}
|
||||||
virtual ~AllreduceFusion() = default;
|
virtual ~AllreduceFusion() = default;
|
||||||
Status ProcessAllreduceFusion(const CNodePtr& ret);
|
Status ProcessAllreduceFusion(const CNodePtr &ret);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status AddNodeToGraph();
|
Status AddNodeToGraph();
|
||||||
CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const;
|
CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const;
|
||||||
CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const;
|
CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const;
|
||||||
Status AddEdgeToGraph();
|
Status AddEdgeToGraph();
|
||||||
std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const;
|
std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const;
|
||||||
Status SetFusion(const std::vector<double>& cost_map);
|
Status SetFusion(const std::vector<double> &cost_map);
|
||||||
Status SetFusionByAlgorithm(int32_t algorithm);
|
Status SetFusionByAlgorithm(int32_t algorithm);
|
||||||
Status SetFusionByBackwardCompTime();
|
Status SetFusionByBackwardCompTime();
|
||||||
Status SetFusionByBackwardCompAndAllreduceTime();
|
Status SetFusionByBackwardCompAndAllreduceTime();
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) {
|
||||||
AllreduceNodePtr arnode;
|
AllreduceNodePtr arnode;
|
||||||
auto cnode_emplace_return = cnode_set_.emplace(node);
|
auto cnode_emplace_return = cnode_set_.emplace(node);
|
||||||
if (!cnode_emplace_return.second) {
|
if (!cnode_emplace_return.second) {
|
||||||
|
@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) {
|
Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
|
||||||
auto from_arnode_iter = cnode_arnode_map_.find(from);
|
auto from_arnode_iter = cnode_arnode_map_.find(from);
|
||||||
if (from_arnode_iter == cnode_arnode_map_.end()) {
|
if (from_arnode_iter == cnode_arnode_map_.end()) {
|
||||||
MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
|
MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
|
||||||
|
@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const {
|
bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
|
||||||
auto cnode_iter = cnode_set_.find(node);
|
auto cnode_iter = cnode_set_.find(node);
|
||||||
return !(cnode_iter == cnode_set_.end());
|
return !(cnode_iter == cnode_set_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
|
std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
|
||||||
std::vector<AnfNodePtr> nodes;
|
std::vector<AnfNodePtr> nodes;
|
||||||
for (auto& cnode_arnode : cnode_arnode_map_) {
|
for (auto &cnode_arnode : cnode_arnode_map_) {
|
||||||
MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
|
MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
|
||||||
<< ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
|
<< ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
|
||||||
<< " curr_para_size: " << cnode_arnode.second->curr_para_size();
|
<< " curr_para_size: " << cnode_arnode.second->curr_para_size();
|
||||||
|
@ -117,7 +117,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
|
||||||
std::vector<AnfNodePtr> nodes;
|
std::vector<AnfNodePtr> nodes;
|
||||||
double cur_para_size = 0;
|
double cur_para_size = 0;
|
||||||
double from = to;
|
double from = to;
|
||||||
for (auto& arnode : arnode_vec_) {
|
for (auto &arnode : arnode_vec_) {
|
||||||
if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
|
if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -135,14 +135,14 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
|
||||||
|
|
||||||
void AllreduceGraph::PrintCNodeSet() const {
|
void AllreduceGraph::PrintCNodeSet() const {
|
||||||
MS_LOG(INFO) << "CNodeSet:";
|
MS_LOG(INFO) << "CNodeSet:";
|
||||||
for (auto& cnode : cnode_set_) {
|
for (auto &cnode : cnode_set_) {
|
||||||
MS_LOG(INFO) << cnode->DebugString();
|
MS_LOG(INFO) << cnode->DebugString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllreduceGraph::PrintAllredueGraphInfo() const {
|
void AllreduceGraph::PrintAllredueGraphInfo() const {
|
||||||
MS_LOG(INFO) << "max: " << max_;
|
MS_LOG(INFO) << "max: " << max_;
|
||||||
for (auto& cnode_arnode : cnode_arnode_map_) {
|
for (auto &cnode_arnode : cnode_arnode_map_) {
|
||||||
MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
|
MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
|
||||||
MS_LOG(INFO) << "arnode info: ";
|
MS_LOG(INFO) << "arnode info: ";
|
||||||
cnode_arnode.second->ToString();
|
cnode_arnode.second->ToString();
|
||||||
|
@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const {
|
||||||
|
|
||||||
void AllreduceGraph::PrintArnodeVec() const {
|
void AllreduceGraph::PrintArnodeVec() const {
|
||||||
MS_LOG(INFO) << "ArnodeVec:";
|
MS_LOG(INFO) << "ArnodeVec:";
|
||||||
for (auto& arnode : arnode_vec_) {
|
for (auto &arnode : arnode_vec_) {
|
||||||
arnode.ToString();
|
arnode.ToString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllreduceGraph::PrintArnodeSet() const {
|
void AllreduceGraph::PrintArnodeSet() const {
|
||||||
MS_LOG(INFO) << "ArnodeSet:";
|
MS_LOG(INFO) << "ArnodeSet:";
|
||||||
for (auto& arnode : arnode_set_) {
|
for (auto &arnode : arnode_set_) {
|
||||||
arnode->ToString();
|
arnode->ToString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllreduceGraph::SortArnode() {
|
void AllreduceGraph::SortArnode() {
|
||||||
arnode_vec_.clear();
|
arnode_vec_.clear();
|
||||||
for (auto& node : arnode_set_) {
|
for (auto &node : arnode_set_) {
|
||||||
arnode_vec_.emplace_back(*node);
|
arnode_vec_.emplace_back(*node);
|
||||||
}
|
}
|
||||||
std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
|
std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
|
||||||
|
@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() {
|
||||||
|
|
||||||
Status AllreduceGraph::RemoveExtraParas() {
|
Status AllreduceGraph::RemoveExtraParas() {
|
||||||
std::unordered_set<AnfNodePtr> para_map;
|
std::unordered_set<AnfNodePtr> para_map;
|
||||||
for (auto& node : arnode_vec_) {
|
for (auto &node : arnode_vec_) {
|
||||||
for (auto& para : node.paras()) {
|
for (auto ¶ : node.paras()) {
|
||||||
auto emplac_result = para_map.emplace(para);
|
auto emplac_result = para_map.emplace(para);
|
||||||
if (!emplac_result.second) {
|
if (!emplac_result.second) {
|
||||||
MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
|
MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
|
||||||
|
@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceGraph::set_head_cnode(const CNodePtr& node) {
|
Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
|
||||||
auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
|
auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
|
||||||
if (arnode->Init(node) != SUCCESS) {
|
if (arnode->Init(node) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "AllreduceNode Init failed";
|
MS_LOG(ERROR) << "AllreduceNode Init failed";
|
||||||
|
|
|
@ -42,9 +42,9 @@ class AllreduceGraph {
|
||||||
cnode_arnode_map_(),
|
cnode_arnode_map_(),
|
||||||
max_(0) {}
|
max_(0) {}
|
||||||
virtual ~AllreduceGraph() = default;
|
virtual ~AllreduceGraph() = default;
|
||||||
Status AddNode(const CNodePtr& node, const AnfNodePtr& para);
|
Status AddNode(const CNodePtr &node, const AnfNodePtr ¶);
|
||||||
Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist);
|
Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist);
|
||||||
bool NodeInGraph(const CNodePtr& node) const;
|
bool NodeInGraph(const CNodePtr &node) const;
|
||||||
std::vector<AnfNodePtr> GetParaByCost(double from, double to);
|
std::vector<AnfNodePtr> GetParaByCost(double from, double to);
|
||||||
// Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is
|
// Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is
|
||||||
// over para_size.
|
// over para_size.
|
||||||
|
@ -60,9 +60,9 @@ class AllreduceGraph {
|
||||||
void PrintAllredueGraphInfo() const;
|
void PrintAllredueGraphInfo() const;
|
||||||
void PrintArnodeVec() const;
|
void PrintArnodeVec() const;
|
||||||
void PrintArnodeSet() const;
|
void PrintArnodeSet() const;
|
||||||
const std::unordered_set<CNodePtr>& cnode_set() const { return cnode_set_; }
|
const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; }
|
||||||
CNodePtr head_cnode() const { return head_cnode_; }
|
CNodePtr head_cnode() const { return head_cnode_; }
|
||||||
Status set_head_cnode(const CNodePtr& node);
|
Status set_head_cnode(const CNodePtr &node);
|
||||||
double max() const { return max_; }
|
double max() const { return max_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue