forked from mindspore-Ecosystem/mindspore
move default_param out of parameter and remove pybind11 in anf define
This commit is contained in:
parent
9b8b699eb3
commit
40e15996b0
|
@ -26,6 +26,7 @@
|
|||
#include "utils/graph_utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
#include "pipeline/parse/resolve.h"
|
||||
#include "operator/composite/composite.h"
|
||||
|
@ -469,7 +470,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNode
|
|||
MS_LOG(EXCEPTION) << "Param could not cast to parameter";
|
||||
}
|
||||
if (param_ptr->has_default()) {
|
||||
ofs << " = @" << DumpObject(param_ptr->default_param(), "D");
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
|
||||
ofs << " = @" << DumpObject(param_value->value(), "D");
|
||||
}
|
||||
|
||||
// output comment
|
||||
|
@ -1650,7 +1652,8 @@ class IrParser {
|
|||
|
||||
// load parameter default value from serialized file
|
||||
py::object default_obj = LoadObject(lexer_.GetTokenText());
|
||||
param->set_default_param(default_obj);
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(default_obj);
|
||||
param->set_default_param(param_value_new);
|
||||
|
||||
tok = lexer_.GetNextToken();
|
||||
}
|
||||
|
|
|
@ -21,12 +21,17 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "operator/composite/composite.h"
|
||||
#include "ir/meta_tensor.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
// namespace to support debug utils
|
||||
|
@ -312,17 +317,21 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
|
|||
for (auto ¶meter : key->parameters()) {
|
||||
buffer_ << "<tr><td>";
|
||||
buffer_ << parameter->ToString();
|
||||
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
|
||||
if (py::hasattr(py_p, "default_input")) {
|
||||
py_p = py_p.attr("default_input");
|
||||
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
|
||||
std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
|
||||
py::tuple shape = m_tensor->GetPyTupleShape();
|
||||
buffer_ << "[" << std::string(py::str(shape)) << "]";
|
||||
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
|
||||
std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
|
||||
py::tuple shape = m_tensor->GetPyTupleShape();
|
||||
buffer_ << "[" << std::string(py::str(shape)) << "]";
|
||||
auto param = parameter->cast<ParameterPtr>();
|
||||
if (param->has_default()) {
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
|
||||
auto py_p = param_value->value();
|
||||
if (py::hasattr(py_p, "default_input")) {
|
||||
py_p = py_p.attr("default_input");
|
||||
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
|
||||
auto m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
|
||||
py::tuple shape = m_tensor->GetPyTupleShape();
|
||||
buffer_ << "[" << py::str(shape) << "]";
|
||||
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
|
||||
auto m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
|
||||
py::tuple shape = m_tensor->GetPyTupleShape();
|
||||
buffer_ << "[" << py::str(shape) << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
buffer_ << "</td></tr>";
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
#include <utility>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <climits>
|
||||
#include "ir/anf.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
|
||||
|
|
|
@ -24,13 +24,10 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/base.h"
|
||||
#include "debug/trace_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
||||
// namespace to support intermediate representation definition
|
||||
enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "utils/any.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include "ir/anf.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
|
||||
|
|
|
@ -24,12 +24,9 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
||||
class TraceInfo;
|
||||
using TraceInfoPtr = std::shared_ptr<TraceInfo>;
|
||||
class Location;
|
||||
|
|
|
@ -23,21 +23,11 @@
|
|||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ir/visitor.h"
|
||||
#include "pipeline/static_analysis/static_analysis.h"
|
||||
#include "operator/ops.h"
|
||||
#include "parallel/ops_info/ops_utils.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support intermediate representation definition
|
||||
// Methods of AnfNode
|
||||
TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
|
||||
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
|
||||
|
||||
std::string AnfNode::ToString() const {
|
||||
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
|
||||
}
|
||||
|
||||
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
|
||||
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
|
||||
|
||||
|
@ -85,66 +75,6 @@ std::string CNode::DebugString(int recursive_level) const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
|
||||
if (operator_info_ != nullptr) {
|
||||
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
|
||||
<< ", using the new one: " << operator_info->name();
|
||||
auto old_ptr = operator_info_;
|
||||
operator_info_ = operator_info;
|
||||
return old_ptr;
|
||||
}
|
||||
operator_info_ = operator_info;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string CNode::fullname_with_scope() {
|
||||
// if full name is set, return its name immediately
|
||||
if (!fullname_with_scope_.empty()) {
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
|
||||
IsApply(prim::kPrimHistogramSummary)) {
|
||||
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
|
||||
if (tag == "") {
|
||||
MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
|
||||
}
|
||||
std::string name;
|
||||
if (IsApply(prim::kPrimScalarSummary)) {
|
||||
name = tag + "[:Scalar]";
|
||||
} else if (IsApply(prim::kPrimImageSummary)) {
|
||||
name = tag + "[:Image]";
|
||||
} else if (IsApply(prim::kPrimHistogramSummary)) {
|
||||
name = tag + "[:Histogram]";
|
||||
} else {
|
||||
name = tag + "[:Tensor]";
|
||||
}
|
||||
fullname_with_scope_ = name;
|
||||
} else {
|
||||
// cnode input 0 should be primitive ptr
|
||||
auto value_ptr = input(0)->cast<ValueNodePtr>();
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
auto input_value = value_ptr->value();
|
||||
if (input_value == nullptr) {
|
||||
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
|
||||
MS_EXCEPTION_IF_NULL(scope());
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
fullname_with_scope_ =
|
||||
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
|
||||
}
|
||||
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
std::string ValueNode::ToString() const {
|
||||
MS_EXCEPTION_IF_NULL(value_);
|
||||
if (value_->isa<FuncGraph>()) {
|
||||
|
@ -173,10 +103,6 @@ std::string ValueNode::fullname_with_scope() {
|
|||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
|
||||
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
|
||||
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
|
||||
|
||||
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
|
|
@ -52,6 +52,7 @@ class AbstractBase;
|
|||
} // namespace abstract
|
||||
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
|
||||
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
|
||||
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
|
||||
|
||||
class ValueNode;
|
||||
using ValueNodePtr = std::shared_ptr<ValueNode>;
|
||||
|
@ -78,6 +79,13 @@ using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
|
|||
|
||||
class AnfVisitor;
|
||||
|
||||
class ParamValue {
|
||||
public:
|
||||
ParamValue() = default;
|
||||
virtual ~ParamValue() = default;
|
||||
};
|
||||
using ParamValuePtr = std::shared_ptr<ParamValue>;
|
||||
|
||||
// AnfNode is the basic class of the IR definition derived from Base.
|
||||
// Only two types of nodes are derived: CNode and ANode.
|
||||
// Methods:
|
||||
|
@ -239,11 +247,11 @@ class ANode : public AnfNode {
|
|||
|
||||
// Parameter represents the parameter inputs of a function. They have no value.
|
||||
// Attributes:
|
||||
// default_param_: used to hold the inputting tensor of the model.
|
||||
// default_param_value_: used to hold the inputting tensor of the model.
|
||||
class Parameter : public ANode {
|
||||
public:
|
||||
explicit Parameter(const FuncGraphPtr &func_graph)
|
||||
: ANode(func_graph), name_(""), has_default_(false), default_param_(py::none()), tensor_layout_(nullptr) {}
|
||||
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
|
||||
~Parameter() override = default;
|
||||
MS_DECLARE_PARENT(Parameter, ANode);
|
||||
|
||||
|
@ -254,12 +262,11 @@ class Parameter : public ANode {
|
|||
std::string fullname_with_scope() override { return name(); };
|
||||
|
||||
bool has_default() const { return has_default_; }
|
||||
|
||||
py::object default_param() { return default_param_; }
|
||||
void set_default_param(const py::object &obj) {
|
||||
default_param_ = obj;
|
||||
void set_default_param(ParamValuePtr param) {
|
||||
default_param_ = param;
|
||||
has_default_ = true;
|
||||
}
|
||||
ParamValuePtr default_param() const { return default_param_; }
|
||||
|
||||
std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; }
|
||||
void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) {
|
||||
|
@ -280,7 +287,7 @@ class Parameter : public ANode {
|
|||
private:
|
||||
std::string name_;
|
||||
bool has_default_;
|
||||
py::object default_param_;
|
||||
ParamValuePtr default_param_;
|
||||
std::shared_ptr<parallel::TensorLayout> tensor_layout_;
|
||||
};
|
||||
using ParameterPtr = std::shared_ptr<Parameter>;
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/anf.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ir/visitor.h"
|
||||
#include "pipeline/static_analysis/static_analysis.h"
|
||||
#include "operator/ops.h"
|
||||
#include "parallel/ops_info/ops_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support intermediate representation definition
|
||||
// Methods of AnfNode
|
||||
TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
|
||||
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
|
||||
|
||||
std::string AnfNode::ToString() const {
|
||||
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
|
||||
}
|
||||
|
||||
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
|
||||
if (operator_info_ != nullptr) {
|
||||
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
|
||||
<< ", using the new one: " << operator_info->name();
|
||||
auto old_ptr = operator_info_;
|
||||
operator_info_ = operator_info;
|
||||
return old_ptr;
|
||||
}
|
||||
operator_info_ = operator_info;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string CNode::fullname_with_scope() {
|
||||
// if full name is set, return its name immediately
|
||||
if (!fullname_with_scope_.empty()) {
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
|
||||
IsApply(prim::kPrimHistogramSummary)) {
|
||||
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
|
||||
if (tag == "") {
|
||||
MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
|
||||
}
|
||||
std::string name;
|
||||
if (IsApply(prim::kPrimScalarSummary)) {
|
||||
name = tag + "[:Scalar]";
|
||||
} else if (IsApply(prim::kPrimImageSummary)) {
|
||||
name = tag + "[:Image]";
|
||||
} else if (IsApply(prim::kPrimHistogramSummary)) {
|
||||
name = tag + "[:Histogram]";
|
||||
} else {
|
||||
name = tag + "[:Tensor]";
|
||||
}
|
||||
fullname_with_scope_ = name;
|
||||
} else {
|
||||
// cnode input 0 should be primitive ptr
|
||||
auto value_ptr = input(0)->cast<ValueNodePtr>();
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
auto input_value = value_ptr->value();
|
||||
if (input_value == nullptr) {
|
||||
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
|
||||
MS_EXCEPTION_IF_NULL(scope());
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
fullname_with_scope_ =
|
||||
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
|
||||
}
|
||||
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
|
||||
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
|
||||
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
|
||||
|
||||
} // namespace mindspore
|
|
@ -19,9 +19,6 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
TypePtr Keyword::DeepCopy() const {
|
||||
|
@ -206,8 +203,6 @@ std::string Function::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
|
||||
|
||||
TypePtr JTagged::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(subtype_);
|
||||
if (IsGeneric()) {
|
||||
|
@ -247,460 +242,4 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
|
|||
os << problem->ToString();
|
||||
return os;
|
||||
}
|
||||
|
||||
std::size_t TypeHasher::operator()(TypePtr const &type) const {
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
std::size_t hash = std::hash<size_t>()(type->type_id());
|
||||
return hash;
|
||||
}
|
||||
|
||||
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
|
||||
std::size_t hash_sum = 0;
|
||||
for (auto &type : type_list) {
|
||||
auto type_id = static_cast<std::size_t>(type->type_id());
|
||||
hash_sum = hash_combine(hash_sum, type_id);
|
||||
}
|
||||
return hash_sum;
|
||||
}
|
||||
|
||||
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
MS_EXCEPTION_IF_NULL(t2);
|
||||
return t1->type_id() == t2->type_id();
|
||||
}
|
||||
|
||||
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
std::size_t size = lhs.size();
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
MS_EXCEPTION_IF_NULL(lhs[i]);
|
||||
MS_EXCEPTION_IF_NULL(rhs[i]);
|
||||
if (*lhs[i] != *rhs[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
TypePtr TypeIdToType(TypeId id) {
|
||||
switch (id) {
|
||||
case kNumberTypeFloat16:
|
||||
return kFloat16;
|
||||
case kNumberTypeFloat:
|
||||
case kNumberTypeFloat32:
|
||||
return kFloat32;
|
||||
case kNumberTypeFloat64:
|
||||
return kFloat64;
|
||||
case kNumberTypeInt8:
|
||||
return kInt8;
|
||||
case kNumberTypeInt16:
|
||||
return kInt16;
|
||||
case kNumberTypeInt32:
|
||||
return kInt32;
|
||||
case kNumberTypeInt64:
|
||||
return kInt64;
|
||||
case kNumberTypeUInt8:
|
||||
return kUInt8;
|
||||
case kNumberTypeUInt16:
|
||||
return kUInt16;
|
||||
case kNumberTypeUInt32:
|
||||
return kUInt32;
|
||||
case kNumberTypeUInt64:
|
||||
return kUInt64;
|
||||
case kNumberTypeBool:
|
||||
return kBool;
|
||||
case kMetaTypeExternal:
|
||||
return kTypeExternal;
|
||||
case kMetaTypeAnything:
|
||||
return kAnyType;
|
||||
case kMetaTypeNone:
|
||||
return kTypeNone;
|
||||
case kObjectTypeEnvType:
|
||||
return kTypeEnv;
|
||||
case kObjectTypeRefKey:
|
||||
return kRefKeyType;
|
||||
case kObjectTypeRef:
|
||||
return kRefType;
|
||||
case kTypeUnknown:
|
||||
return kTypeNone;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Not support the type: " << id;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == num_type_name) {
|
||||
type = std::make_shared<T>();
|
||||
} else {
|
||||
try {
|
||||
if (num_type_name.size() >= type_name.size()) {
|
||||
MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
|
||||
<< ")";
|
||||
}
|
||||
auto bits = std::stoi(type_name.substr(num_type_name.size()));
|
||||
type = std::make_shared<T>(bits);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
|
||||
std::vector<TypePtr> types;
|
||||
if (type_names.length() == 0) {
|
||||
return types;
|
||||
}
|
||||
std::string::size_type start = 0;
|
||||
std::string::size_type end = type_names.find_first_of(',');
|
||||
while (end != std::string::npos) {
|
||||
types.push_back(StringToType(type_names.substr(start, end)));
|
||||
// Skip ',' to find the next element.
|
||||
start = end + 1;
|
||||
end = type_names.find_first_of(',', start);
|
||||
}
|
||||
if (start >= type_names.size()) {
|
||||
MS_LOG(EXCEPTION) << "Type name is empty string.";
|
||||
}
|
||||
types.push_back(StringToType(type_names.substr(start)));
|
||||
return types;
|
||||
}
|
||||
|
||||
TypePtr TensorStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "Tensor") {
|
||||
type = std::make_shared<TensorType>();
|
||||
} else {
|
||||
try {
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<TensorType>(element_type);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr ListStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "List") {
|
||||
type = std::make_shared<List>();
|
||||
} else {
|
||||
try {
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string element_strs = type_name.substr(start, end - start);
|
||||
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
||||
bool wrong =
|
||||
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||
if (wrong) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<List>(element_types);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr TupleStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "Tuple") {
|
||||
type = std::make_shared<Tuple>();
|
||||
} else {
|
||||
try {
|
||||
size_t start = type_name.find_first_of('[') + 1;
|
||||
size_t end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string element_strs = type_name.substr(start, end - start);
|
||||
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
||||
bool wrong =
|
||||
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||
if (wrong) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<Tuple>(element_types);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr FunctionStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
|
||||
if (type_name == "Function") {
|
||||
type = std::make_shared<Function>();
|
||||
} else {
|
||||
try {
|
||||
// format: [(para1, para2, para3, ...) retval]
|
||||
size_t start = type_name.find_first_of('[') + 1;
|
||||
size_t end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string str_all = type_name.substr(start, end - start);
|
||||
size_t start_a = str_all.find_first_of('(') + 1;
|
||||
size_t end_a = str_all.find_last_of(')');
|
||||
if (start_a >= str_all.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string str_args = str_all.substr(start_a, end_a - start_a);
|
||||
// bypass " " between ")" and retval
|
||||
start = end_a + 2;
|
||||
if (start >= str_all.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string str_retval = str_all.substr(start);
|
||||
|
||||
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
|
||||
TypePtr retval = StringToType(str_retval);
|
||||
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||
if (retval == nullptr || wrong) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<Function>(args_type, retval);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TypePtr StringToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name.compare("None") == 0) {
|
||||
type = std::make_shared<TypeNone>();
|
||||
} else if (type_name.compare("Ellipsis") == 0) {
|
||||
type = std::make_shared<Ellipsis>();
|
||||
} else if (type_name.compare("TypeType") == 0) {
|
||||
type = std::make_shared<TypeType>();
|
||||
} else if (type_name.compare("SymbolicKeyType") == 0) {
|
||||
type = std::make_shared<SymbolicKeyType>();
|
||||
} else if (type_name.compare("RefKeyType") == 0) {
|
||||
type = std::make_shared<RefKeyType>();
|
||||
} else if (type_name.compare("EnvType") == 0) {
|
||||
type = std::make_shared<EnvType>();
|
||||
} else if (type_name.compare("Number") == 0) {
|
||||
type = std::make_shared<Number>();
|
||||
} else if (type_name.compare("Bool") == 0) {
|
||||
type = std::make_shared<Bool>();
|
||||
} else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
|
||||
type = StringToNumberType<Int>(type_name, "Int");
|
||||
} else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
|
||||
type = StringToNumberType<UInt>(type_name, "UInt");
|
||||
} else if (type_name.compare(0, strlen("Float"), "Float") == 0) {
|
||||
type = StringToNumberType<Float>(type_name, "Float");
|
||||
} else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
|
||||
type = TensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||
type = ListStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||
type = TupleStrToType(type_name);
|
||||
} else if (type_name.compare("Slice") == 0) {
|
||||
type = std::make_shared<Slice>();
|
||||
} else if (type_name.compare("Dictionary") == 0) {
|
||||
type = std::make_shared<Dictionary>();
|
||||
} else if (type_name.compare("String") == 0) {
|
||||
type = std::make_shared<String>();
|
||||
} else if (type_name.compare("Problem") == 0) {
|
||||
type = std::make_shared<Problem>();
|
||||
} else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
|
||||
type = FunctionStrToType(type_name);
|
||||
} else {
|
||||
// - unsupported to convert
|
||||
// Class
|
||||
// SymbolicType
|
||||
// JTagged
|
||||
// Anything
|
||||
// External
|
||||
// Problem
|
||||
MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
|
||||
if (x == nullptr || base_type == nullptr) {
|
||||
MS_LOG(ERROR) << "Type is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
} else if (!(base_type->IsGeneric())) {
|
||||
return *(base_type) == *(x);
|
||||
} else if (base_type->type_id() == x->type_id()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->generic_type_id()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->object_type()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->meta_type()) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
if (t1->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
} else if (t2 != nullptr) {
|
||||
return IsIdentidityOrSubclass(t1, t2);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
typing, ([](py::module *const m) {
|
||||
auto m_sub = m->def_submodule("typing", "submodule for dtype");
|
||||
py::enum_<TypeId>(m_sub, "TypeId");
|
||||
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
|
||||
(void)m_sub.def("load_type", &TypeIdToType, "load type");
|
||||
(void)m_sub.def(
|
||||
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
|
||||
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
||||
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
||||
.def("__eq__",
|
||||
[](const TypePtr &t1, const TypePtr &t2) {
|
||||
if (t1 != nullptr && t2 != nullptr) {
|
||||
return *t1 == *t2;
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.def("__hash__", &Type::hash)
|
||||
.def("__str__", &Type::ToString)
|
||||
.def("__repr__", &Type::ReprString)
|
||||
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
|
||||
if (t == nullptr) {
|
||||
return static_cast<TypePtr>(nullptr);
|
||||
}
|
||||
return t->DeepCopy();
|
||||
});
|
||||
(void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
|
||||
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
|
||||
.def(py::init())
|
||||
.def(py::pickle(
|
||||
[](const Bool &) { // __getstate__
|
||||
return py::make_tuple();
|
||||
},
|
||||
[](const py::tuple &) { // __setstate__
|
||||
return std::make_shared<Bool>();
|
||||
}));
|
||||
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const Int &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
Int data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const UInt &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
UInt data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const Float &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
Float data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
|
||||
(void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
|
||||
(void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType")
|
||||
.def(py::init())
|
||||
.def(py::init<TypePtr>(), py::arg("element"))
|
||||
.def("element_type", &TensorType::element)
|
||||
.def(py::pickle(
|
||||
[](const TensorType &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
|
||||
(void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init());
|
||||
(void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init());
|
||||
(void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init());
|
||||
(void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init());
|
||||
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
|
||||
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
|
||||
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
||||
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
||||
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
||||
}));
|
||||
|
||||
const TypePtr kTypeExternal = std::make_shared<External>();
|
||||
const TypePtr kTypeEnv = std::make_shared<EnvType>();
|
||||
const TypePtr kTypeType = std::make_shared<TypeType>();
|
||||
const TypePtr kTensorType = std::make_shared<TensorType>();
|
||||
const TypePtr kString = std::make_shared<String>();
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,9 +19,6 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
|
||||
|
|
|
@ -19,9 +19,6 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
bool Number::operator==(const Type &other) const {
|
||||
|
|
|
@ -19,9 +19,6 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
TypePtr RefType::DeepCopy() const {
|
||||
|
|
|
@ -21,9 +21,8 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
#include "ir/dtype/number.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
TypeId IntBitsToTypeId(const int nbits) {
|
||||
|
@ -227,11 +226,6 @@ bool Type::operator==(const Value &other) const {
|
|||
}
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr Type::ToAbstract() {
|
||||
abstract::AbstractBasePtr ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
|
||||
return ptr;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const Type &type) {
|
||||
os << type.ToString();
|
||||
return os;
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/dtype/type.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
abstract::AbstractBasePtr Type::ToAbstract() {
|
||||
auto ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
|
||||
return ptr;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,484 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/dtype.h"
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
|
||||
|
||||
std::size_t TypeHasher::operator()(TypePtr const &type) const {
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
std::size_t hash = std::hash<size_t>()(type->type_id());
|
||||
return hash;
|
||||
}
|
||||
|
||||
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
|
||||
std::size_t hash_sum = 0;
|
||||
for (auto &type : type_list) {
|
||||
auto type_id = static_cast<std::size_t>(type->type_id());
|
||||
hash_sum = hash_combine(hash_sum, type_id);
|
||||
}
|
||||
return hash_sum;
|
||||
}
|
||||
|
||||
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
MS_EXCEPTION_IF_NULL(t2);
|
||||
return t1->type_id() == t2->type_id();
|
||||
}
|
||||
|
||||
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
std::size_t size = lhs.size();
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
MS_EXCEPTION_IF_NULL(lhs[i]);
|
||||
MS_EXCEPTION_IF_NULL(rhs[i]);
|
||||
if (*lhs[i] != *rhs[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
TypePtr TypeIdToType(TypeId id) {
|
||||
switch (id) {
|
||||
case kNumberTypeFloat16:
|
||||
return kFloat16;
|
||||
case kNumberTypeFloat:
|
||||
case kNumberTypeFloat32:
|
||||
return kFloat32;
|
||||
case kNumberTypeFloat64:
|
||||
return kFloat64;
|
||||
case kNumberTypeInt8:
|
||||
return kInt8;
|
||||
case kNumberTypeInt16:
|
||||
return kInt16;
|
||||
case kNumberTypeInt32:
|
||||
return kInt32;
|
||||
case kNumberTypeInt64:
|
||||
return kInt64;
|
||||
case kNumberTypeUInt8:
|
||||
return kUInt8;
|
||||
case kNumberTypeUInt16:
|
||||
return kUInt16;
|
||||
case kNumberTypeUInt32:
|
||||
return kUInt32;
|
||||
case kNumberTypeUInt64:
|
||||
return kUInt64;
|
||||
case kNumberTypeBool:
|
||||
return kBool;
|
||||
case kMetaTypeExternal:
|
||||
return kTypeExternal;
|
||||
case kMetaTypeAnything:
|
||||
return kAnyType;
|
||||
case kMetaTypeNone:
|
||||
return kTypeNone;
|
||||
case kObjectTypeEnvType:
|
||||
return kTypeEnv;
|
||||
case kObjectTypeRefKey:
|
||||
return kRefKeyType;
|
||||
case kObjectTypeRef:
|
||||
return kRefType;
|
||||
case kTypeUnknown:
|
||||
return kTypeNone;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Not support the type: " << id;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == num_type_name) {
|
||||
type = std::make_shared<T>();
|
||||
} else {
|
||||
try {
|
||||
if (num_type_name.size() >= type_name.size()) {
|
||||
MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
|
||||
<< ")";
|
||||
}
|
||||
auto bits = std::stoi(type_name.substr(num_type_name.size()));
|
||||
type = std::make_shared<T>(bits);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
|
||||
std::vector<TypePtr> types;
|
||||
if (type_names.length() == 0) {
|
||||
return types;
|
||||
}
|
||||
std::string::size_type start = 0;
|
||||
std::string::size_type end = type_names.find_first_of(',');
|
||||
while (end != std::string::npos) {
|
||||
types.push_back(StringToType(type_names.substr(start, end)));
|
||||
// Skip ',' to find the next element.
|
||||
start = end + 1;
|
||||
end = type_names.find_first_of(',', start);
|
||||
}
|
||||
if (start >= type_names.size()) {
|
||||
MS_LOG(EXCEPTION) << "Type name is empty string.";
|
||||
}
|
||||
types.push_back(StringToType(type_names.substr(start)));
|
||||
return types;
|
||||
}
|
||||
|
||||
TypePtr TensorStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "Tensor") {
|
||||
type = std::make_shared<TensorType>();
|
||||
} else {
|
||||
try {
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<TensorType>(element_type);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr ListStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "List") {
|
||||
type = std::make_shared<List>();
|
||||
} else {
|
||||
try {
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string element_strs = type_name.substr(start, end - start);
|
||||
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
||||
bool wrong =
|
||||
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||
if (wrong) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<List>(element_types);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr TupleStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "Tuple") {
|
||||
type = std::make_shared<Tuple>();
|
||||
} else {
|
||||
try {
|
||||
size_t start = type_name.find_first_of('[') + 1;
|
||||
size_t end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string element_strs = type_name.substr(start, end - start);
|
||||
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
|
||||
bool wrong =
|
||||
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||
if (wrong) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<Tuple>(element_types);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr FunctionStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
|
||||
if (type_name == "Function") {
|
||||
type = std::make_shared<Function>();
|
||||
} else {
|
||||
try {
|
||||
// format: [(para1, para2, para3, ...) retval]
|
||||
size_t start = type_name.find_first_of('[') + 1;
|
||||
size_t end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string str_all = type_name.substr(start, end - start);
|
||||
size_t start_a = str_all.find_first_of('(') + 1;
|
||||
size_t end_a = str_all.find_last_of(')');
|
||||
if (start_a >= str_all.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string str_args = str_all.substr(start_a, end_a - start_a);
|
||||
// bypass " " between ")" and retval
|
||||
start = end_a + 2;
|
||||
if (start >= str_all.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string str_retval = str_all.substr(start);
|
||||
|
||||
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
|
||||
TypePtr retval = StringToType(str_retval);
|
||||
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
|
||||
if (retval == nullptr || wrong) {
|
||||
return nullptr;
|
||||
}
|
||||
type = std::make_shared<Function>(args_type, retval);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TypePtr StringToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name.compare("None") == 0) {
|
||||
type = std::make_shared<TypeNone>();
|
||||
} else if (type_name.compare("Ellipsis") == 0) {
|
||||
type = std::make_shared<Ellipsis>();
|
||||
} else if (type_name.compare("TypeType") == 0) {
|
||||
type = std::make_shared<TypeType>();
|
||||
} else if (type_name.compare("SymbolicKeyType") == 0) {
|
||||
type = std::make_shared<SymbolicKeyType>();
|
||||
} else if (type_name.compare("RefKeyType") == 0) {
|
||||
type = std::make_shared<RefKeyType>();
|
||||
} else if (type_name.compare("EnvType") == 0) {
|
||||
type = std::make_shared<EnvType>();
|
||||
} else if (type_name.compare("Number") == 0) {
|
||||
type = std::make_shared<Number>();
|
||||
} else if (type_name.compare("Bool") == 0) {
|
||||
type = std::make_shared<Bool>();
|
||||
} else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
|
||||
type = StringToNumberType<Int>(type_name, "Int");
|
||||
} else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
|
||||
type = StringToNumberType<UInt>(type_name, "UInt");
|
||||
} else if (type_name.compare(0, strlen("Float"), "Float") == 0) {
|
||||
type = StringToNumberType<Float>(type_name, "Float");
|
||||
} else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
|
||||
type = TensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||
type = ListStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||
type = TupleStrToType(type_name);
|
||||
} else if (type_name.compare("Slice") == 0) {
|
||||
type = std::make_shared<Slice>();
|
||||
} else if (type_name.compare("Dictionary") == 0) {
|
||||
type = std::make_shared<Dictionary>();
|
||||
} else if (type_name.compare("String") == 0) {
|
||||
type = std::make_shared<String>();
|
||||
} else if (type_name.compare("Problem") == 0) {
|
||||
type = std::make_shared<Problem>();
|
||||
} else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
|
||||
type = FunctionStrToType(type_name);
|
||||
} else {
|
||||
// - unsupported to convert
|
||||
// Class
|
||||
// SymbolicType
|
||||
// JTagged
|
||||
// Anything
|
||||
// External
|
||||
// Problem
|
||||
MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
|
||||
if (x == nullptr || base_type == nullptr) {
|
||||
MS_LOG(ERROR) << "Type is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
} else if (!(base_type->IsGeneric())) {
|
||||
return *(base_type) == *(x);
|
||||
} else if (base_type->type_id() == x->type_id()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->generic_type_id()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->object_type()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->meta_type()) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
if (t1->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
} else if (t2 != nullptr) {
|
||||
return IsIdentidityOrSubclass(t1, t2);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
typing, ([](py::module *const m) {
|
||||
auto m_sub = m->def_submodule("typing", "submodule for dtype");
|
||||
py::enum_<TypeId>(m_sub, "TypeId");
|
||||
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
|
||||
(void)m_sub.def("load_type", &TypeIdToType, "load type");
|
||||
(void)m_sub.def(
|
||||
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
|
||||
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
||||
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
||||
.def("__eq__",
|
||||
[](const TypePtr &t1, const TypePtr &t2) {
|
||||
if (t1 != nullptr && t2 != nullptr) {
|
||||
return *t1 == *t2;
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.def("__hash__", &Type::hash)
|
||||
.def("__str__", &Type::ToString)
|
||||
.def("__repr__", &Type::ReprString)
|
||||
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
|
||||
if (t == nullptr) {
|
||||
return static_cast<TypePtr>(nullptr);
|
||||
}
|
||||
return t->DeepCopy();
|
||||
});
|
||||
(void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
|
||||
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
|
||||
.def(py::init())
|
||||
.def(py::pickle(
|
||||
[](const Bool &) { // __getstate__
|
||||
return py::make_tuple();
|
||||
},
|
||||
[](const py::tuple &) { // __setstate__
|
||||
return std::make_shared<Bool>();
|
||||
}));
|
||||
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const Int &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
Int data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const UInt &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
UInt data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const Float &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
Float data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
|
||||
(void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
|
||||
(void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType")
|
||||
.def(py::init())
|
||||
.def(py::init<TypePtr>(), py::arg("element"))
|
||||
.def("element_type", &TensorType::element)
|
||||
.def(py::pickle(
|
||||
[](const TensorType &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
|
||||
(void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init());
|
||||
(void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init());
|
||||
(void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init());
|
||||
(void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init());
|
||||
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
|
||||
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
|
||||
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
||||
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
||||
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
||||
}));
|
||||
|
||||
const TypePtr kTypeExternal = std::make_shared<External>();
|
||||
const TypePtr kTypeEnv = std::make_shared<EnvType>();
|
||||
const TypePtr kTypeType = std::make_shared<TypeType>();
|
||||
const TypePtr kTensorType = std::make_shared<TensorType>();
|
||||
const TypePtr kString = std::make_shared<String>();
|
||||
} // namespace mindspore
|
|
@ -19,6 +19,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "ir/manager.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/profile.h"
|
||||
|
@ -69,7 +70,9 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
|
|||
new_param->set_abstract(old_param->abstract());
|
||||
new_param->set_name(old_param->name());
|
||||
if (old_param->has_default()) {
|
||||
new_param->set_default_param(old_param->default_param());
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
|
||||
new_param->set_default_param(param_value_new);
|
||||
}
|
||||
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
|
||||
new_param->set_scope(scope);
|
||||
|
@ -248,7 +251,9 @@ void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) {
|
|||
if (node->isa<Parameter>()) {
|
||||
ParameterPtr old_param = dyn_cast<Parameter>(node);
|
||||
if (old_param->has_default()) {
|
||||
param->set_default_param(old_param->default_param());
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
|
||||
param->set_default_param(param_value_new);
|
||||
}
|
||||
param->set_name(old_param->name());
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
class Cloner;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "ir/dtype.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/signature.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include "ir/base.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
|
||||
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ParamValueMinnie : public ParamValue {
|
||||
public:
|
||||
ParamValueMinnie() : tensor_addr_(nullptr), tensor_size_(0) {}
|
||||
virtual ~ParamValueMinnie() = default;
|
||||
|
||||
size_t tensor_size() const { return tensor_size_; }
|
||||
void set_tensor_size(size_t size) { tensor_size_ = size; }
|
||||
|
||||
void *tensor_addr() const { return tensor_addr_; }
|
||||
void set_tensor_addr(void *addr) { tensor_addr_ = addr; }
|
||||
|
||||
private:
|
||||
void *tensor_addr_;
|
||||
size_t tensor_size_;
|
||||
};
|
||||
|
||||
using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
|
||||
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
||||
class ParamValuePy : public ParamValue {
|
||||
public:
|
||||
ParamValuePy() : value_(py::none()) {}
|
||||
explicit ParamValuePy(py::object value) : value_(value) {}
|
||||
virtual ~ParamValuePy() = default;
|
||||
|
||||
py::object value() { return value_; }
|
||||
void set_value(const py::object &obj) { value_ = obj; }
|
||||
|
||||
private:
|
||||
py::object value_;
|
||||
};
|
||||
|
||||
using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
|
|
@ -17,20 +17,23 @@
|
|||
#ifndef MINDSPORE_CCSRC_IR_SCALAR_H_
|
||||
#define MINDSPORE_CCSRC_IR_SCALAR_H_
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support inference engine */
|
||||
|
||||
#include <type_traits>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <cfloat>
|
||||
|
||||
#include "ir/base.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/dtype/number.h"
|
||||
|
||||
using std::fabs;
|
||||
|
||||
namespace mindspore {
|
||||
class Scalar : public Value {
|
||||
public:
|
||||
Scalar() = default;
|
||||
|
|
|
@ -19,9 +19,7 @@
|
|||
#include <memory>
|
||||
#include <cmath>
|
||||
#include <cfloat>
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
|
||||
|
@ -208,41 +206,6 @@ bool AnyValue::operator==(const Value &other) const {
|
|||
}
|
||||
}
|
||||
const ValuePtr kAnyValue = std::make_shared<AnyValue>();
|
||||
using ContextPtr = abstract::AnalysisContextPtr;
|
||||
|
||||
abstract::AbstractBasePtr Scalar::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr StringImm::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr RefKey::ToAbstract() {
|
||||
auto refkey = std::make_shared<abstract::AbstractRefKey>();
|
||||
refkey->set_value(shared_from_base<Value>());
|
||||
return refkey;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }
|
||||
|
||||
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
|
||||
abstract::AbstractBasePtrList a_list;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
return ele->ToAbstract();
|
||||
});
|
||||
return std::make_shared<abstract::AbstractTuple>(a_list);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ValueList::ToAbstract() {
|
||||
abstract::AbstractBasePtrList a_list;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
return ele->ToAbstract();
|
||||
});
|
||||
return std::make_shared<abstract::AbstractList>(a_list);
|
||||
}
|
||||
|
||||
std::size_t ValueSlice::hash() const {
|
||||
MS_EXCEPTION_IF_NULL(start_);
|
||||
|
@ -280,16 +243,6 @@ std::string ValueSlice::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ValueSlice::ToAbstract() {
|
||||
MS_EXCEPTION_IF_NULL(start_);
|
||||
MS_EXCEPTION_IF_NULL(stop_);
|
||||
MS_EXCEPTION_IF_NULL(step_);
|
||||
abstract::AbstractBasePtr start = start_->ToAbstract();
|
||||
abstract::AbstractBasePtr end = stop_->ToAbstract();
|
||||
abstract::AbstractBasePtr step = step_->ToAbstract();
|
||||
return std::make_shared<abstract::AbstractSlice>(start, end, step);
|
||||
}
|
||||
|
||||
std::size_t KeywordArg::hash() const {
|
||||
MS_EXCEPTION_IF_NULL(value_);
|
||||
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
|
||||
|
@ -316,12 +269,6 @@ std::string KeywordArg::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr KeywordArg::ToAbstract() {
|
||||
MS_EXCEPTION_IF_NULL(value_);
|
||||
abstract::AbstractBasePtr argument = value_->ToAbstract();
|
||||
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
|
||||
}
|
||||
|
||||
const ValuePtr ValueDictionary::operator[](const std::string &key) const {
|
||||
auto it = std::find_if(key_values_.begin(), key_values_.end(),
|
||||
[key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
|
||||
|
@ -354,17 +301,4 @@ bool ValueDictionary::operator==(const ValueDictionary &other) const {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
|
||||
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
|
||||
(void)std::transform(
|
||||
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
||||
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
|
||||
return std::make_shared<abstract::AbstractDictionary>(kv);
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
RefKey, ([](const py::module *m) {
|
||||
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
|
||||
}));
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/scalar.h"
|
||||
#include "ir/dtype/ref.h"
|
||||
#include "utils/hashing.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/value.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <cmath>
|
||||
#include <cfloat>
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
using ContextPtr = abstract::AnalysisContextPtr;
|
||||
|
||||
abstract::AbstractBasePtr Scalar::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr StringImm::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr RefKey::ToAbstract() {
|
||||
auto refkey = std::make_shared<abstract::AbstractRefKey>();
|
||||
refkey->set_value(shared_from_base<Value>());
|
||||
return refkey;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }
|
||||
|
||||
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
|
||||
abstract::AbstractBasePtrList a_list;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
return ele->ToAbstract();
|
||||
});
|
||||
return std::make_shared<abstract::AbstractTuple>(a_list);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ValueList::ToAbstract() {
|
||||
abstract::AbstractBasePtrList a_list;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
return ele->ToAbstract();
|
||||
});
|
||||
return std::make_shared<abstract::AbstractList>(a_list);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ValueSlice::ToAbstract() {
|
||||
MS_EXCEPTION_IF_NULL(start_);
|
||||
MS_EXCEPTION_IF_NULL(stop_);
|
||||
MS_EXCEPTION_IF_NULL(step_);
|
||||
abstract::AbstractBasePtr start = start_->ToAbstract();
|
||||
abstract::AbstractBasePtr end = stop_->ToAbstract();
|
||||
abstract::AbstractBasePtr step = step_->ToAbstract();
|
||||
return std::make_shared<abstract::AbstractSlice>(start, end, step);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr KeywordArg::ToAbstract() {
|
||||
MS_EXCEPTION_IF_NULL(value_);
|
||||
abstract::AbstractBasePtr argument = value_->ToAbstract();
|
||||
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
|
||||
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
|
||||
(void)std::transform(
|
||||
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
||||
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
|
||||
return std::make_shared<abstract::AbstractDictionary>(kv);
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
RefKey, ([](const py::module *m) {
|
||||
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
|
||||
}));
|
||||
} // namespace mindspore
|
|
@ -26,6 +26,7 @@
|
|||
#include "debug/anf_ir_utils.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/param_value_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
enum OpMergeMode {
|
||||
|
@ -424,7 +425,8 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP
|
|||
initializer_proto->set_name(param_ptr->ToString());
|
||||
SetTensorProtoInfo(param_ptr, initializer_proto);
|
||||
// set value for initializer
|
||||
py::object obj = param_ptr->default_param();
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
|
||||
py::object obj = param_value->value();
|
||||
py::object data = obj.attr("data");
|
||||
if (py::isinstance<tensor::Tensor>(data)) {
|
||||
auto method = data.attr("asnumpy");
|
||||
|
|
|
@ -18,9 +18,10 @@
|
|||
#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_
|
||||
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
py::dict GetParameterLayout(const FuncGraphPtr &graph);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -37,7 +38,8 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
|
|||
if (!para_ptr->has_default()) {
|
||||
return false;
|
||||
}
|
||||
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(para_ptr->default_param(), "requires_grad"));
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(para_ptr->default_param());
|
||||
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "ir/meta_tensor.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
@ -190,8 +191,8 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
|||
if (input->isa<Parameter>()) {
|
||||
auto input_parameter = input->cast<ParameterPtr>();
|
||||
if (input_parameter->has_default()) {
|
||||
bool require_grad =
|
||||
py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
|
||||
bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
|
||||
is_parameter.push_back(require_grad);
|
||||
} else {
|
||||
is_parameter.push_back(false);
|
||||
|
@ -835,8 +836,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(casted_target_parameter);
|
||||
if (casted_target_parameter->has_default()) {
|
||||
bool require_grad = py::cast<bool>(
|
||||
parse::python_adapter::GetPyObjAttr(casted_target_parameter->default_param(), "requires_grad"));
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param());
|
||||
bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
|
||||
is_parameter.push_back(require_grad);
|
||||
} else {
|
||||
is_parameter.push_back(false);
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <utility>
|
||||
|
||||
#include "ir/meta_tensor.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "parallel/auto_parallel/graph_costmodel.h"
|
||||
|
@ -1292,7 +1293,8 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_nod
|
|||
return false;
|
||||
}
|
||||
|
||||
py::object clone_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO);
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
|
||||
py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
|
||||
bool cloned = py::cast<bool>(parse::python_adapter::GetPyObjAttr(clone_info, CLONED));
|
||||
if (!cloned) {
|
||||
return false;
|
||||
|
@ -1314,7 +1316,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|||
}
|
||||
|
||||
// get the cloned index
|
||||
py::object cloned_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO);
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
|
||||
py::object cloned_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
|
||||
int32_t cloned_index = py::cast<int32_t>(parse::python_adapter::GetPyObjAttr(cloned_info, CLONED_INDEX));
|
||||
|
||||
// find the be cloned parameter
|
||||
|
@ -1329,7 +1332,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|||
continue;
|
||||
}
|
||||
|
||||
py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(be_cloned_parameter->default_param(), CLONE_INFO);
|
||||
auto param_value_cloned = std::dynamic_pointer_cast<ParamValuePy>(be_cloned_parameter->default_param());
|
||||
py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(param_value_cloned->value(), CLONE_INFO);
|
||||
if (!py::cast<bool>(parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED))) {
|
||||
continue;
|
||||
}
|
||||
|
@ -2072,9 +2076,9 @@ std::string NodeParameterName(const CNodePtr &node) {
|
|||
if (input->isa<Parameter>()) {
|
||||
auto input_parameter = input->cast<ParameterPtr>();
|
||||
if (input_parameter->has_default()) {
|
||||
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) {
|
||||
return py::cast<std::string>(
|
||||
parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME));
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
|
||||
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), REQUIRES_GRAD))) {
|
||||
return py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), PARAM_NAME));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <functional>
|
||||
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "parallel/costmodel_context.h"
|
||||
#include "parallel/context.h"
|
||||
#include "pipeline/pass.h"
|
||||
|
@ -225,8 +226,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
for (const auto ¶m : func_graph->parameters()) {
|
||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||
if (param_node->has_default()) {
|
||||
AbstractBasePtr ptr =
|
||||
abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true);
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
|
||||
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
|
||||
|
||||
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
|
||||
args_spec.push_back(ptr);
|
||||
|
|
|
@ -25,8 +25,11 @@
|
|||
#include "operator/ops.h"
|
||||
#include "debug/info.h"
|
||||
#include "debug/trace.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace parse {
|
||||
FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
|
||||
func_graph_ = std::make_shared<FuncGraph>();
|
||||
|
|
|
@ -18,11 +18,13 @@
|
|||
#define PIPELINE_PARSE_PARSE_BASE_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
// define the node type
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ir/param_value_py.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
|
@ -101,8 +102,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
|||
}
|
||||
}
|
||||
if (para_node == nullptr) {
|
||||
ParameterPtr node = top_graph->AddWeightParameter(param_name);
|
||||
node->set_default_param(obj);
|
||||
auto node = top_graph->AddWeightParameter(param_name);
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(obj);
|
||||
node->set_default_param(param_value_new);
|
||||
|
||||
// set_abstract for parameter
|
||||
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ir/param_value_py.h"
|
||||
#include "pipeline/pass.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "optimizer/ad/dfunctor.h"
|
||||
|
@ -619,7 +620,12 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
|
|||
// maybe some default parameter
|
||||
for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
|
||||
MS_EXCEPTION_IF_NULL(graph_params[i]);
|
||||
py::object obj = dyn_cast<Parameter>(graph_params[i])->default_param();
|
||||
auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
|
||||
if (!param_ptr->has_default()) {
|
||||
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
|
||||
}
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
|
||||
py::object obj = param_value->value();
|
||||
py::object p_value = py::cast<py::object>(parse::python_adapter::GetPyObjAttr(obj, "default_input"));
|
||||
(*arg_list).push_back(p_value);
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <unordered_set>
|
||||
#include "common/utils.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
@ -232,7 +233,9 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) {
|
|||
new_parameter->set_abstract(parameter->abstract());
|
||||
new_parameter->set_name(parameter->name());
|
||||
if (AnfAlgo::IsParameterWeight(parameter)) {
|
||||
new_parameter->set_default_param(parameter->default_param());
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
|
||||
new_parameter->set_default_param(param_value_new);
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
} else {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <unordered_set>
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "common/trans.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
|
@ -44,10 +45,11 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) {
|
|||
return nullptr;
|
||||
}
|
||||
auto parameter = node->cast<ParameterPtr>();
|
||||
if (parameter == nullptr) {
|
||||
if (parameter == nullptr || !parameter->has_default()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto py_param = parameter->default_param();
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
|
||||
auto py_param = param_value->value();
|
||||
if (!py::hasattr(py_param, "default_input")) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -315,7 +317,8 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
|
|||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (tensor_mask == 1) {
|
||||
py::object obj;
|
||||
param->set_default_param(obj);
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(obj);
|
||||
param->set_default_param(param_value_new);
|
||||
}
|
||||
// set the kernel info of parameter
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
|
|
|
@ -25,9 +25,11 @@
|
|||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <iterator>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
class BaseRef;
|
||||
class VectorRef;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "utils/callbacks_ge.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "transform/df_graph_manager.h"
|
||||
#include "transform/util.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
|
@ -49,7 +50,11 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name,
|
|||
return false;
|
||||
}
|
||||
if (param_node->name() == param_name) {
|
||||
py::object parameter = param_node->default_param();
|
||||
py::object parameter;
|
||||
if (param_node->has_default()) {
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
|
||||
parameter = param_value->value();
|
||||
}
|
||||
ValuePtr value = parse::data_converter::PyDataToValue(parameter);
|
||||
TensorPtr tensor = std::dynamic_pointer_cast<tensor::Tensor>(value);
|
||||
if (tensor == nullptr) {
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "pipeline/parse/parse_base.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
#include "./common.h"
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/dtype/ref.h"
|
||||
#include "ir/dtype/number.h"
|
||||
#include "ir/dtype/container.h"
|
||||
#include "ir/dtype/empty.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestDType : public UT::Common {
|
||||
|
|
|
@ -92,20 +92,21 @@ class TestTensor : public UT::Common {
|
|||
TestTensor() {}
|
||||
virtual void SetUp() {
|
||||
UT::InitPythonPath();
|
||||
// Init tensor data by py::array_t<float>
|
||||
input_ = py::array_t<float, py::array::c_style>({2, 3});
|
||||
auto array = input_.mutable_unchecked();
|
||||
float start = 0;
|
||||
for (int i = 0; i < array.shape(0); i++) {
|
||||
for (int j = 0; j < array.shape(1); j++) {
|
||||
array(i, j) = start++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
py::array_t<float, py::array::c_style> BuildInputTensor() {
|
||||
// Init tensor data by py::array_t<float>
|
||||
py::array_t<float, py::array::c_style> input = py::array_t<float, py::array::c_style>({2, 3});
|
||||
auto array = input.mutable_unchecked();
|
||||
float start = 0;
|
||||
for (int i = 0; i < array.shape(0); i++) {
|
||||
for (int j = 0; j < array.shape(1); j++) {
|
||||
array(i, j) = start++;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
py::array_t<float, py::array::c_style> input_;
|
||||
};
|
||||
return input;
|
||||
}
|
||||
|
||||
TEST_F(TestTensor, PyArrayScalarTest) {
|
||||
std::vector<int> dimensions;
|
||||
|
@ -246,7 +247,7 @@ TEST_F(TestTensor, PyArrayTest) {
|
|||
|
||||
TEST_F(TestTensor, InitByFloatArrayDataCTest) {
|
||||
// Init tensor data by py::array_t<float>
|
||||
TensorPtr tensor = std::make_shared<Tensor>(input_);
|
||||
auto tensor = std::make_shared<Tensor>(BuildInputTensor());
|
||||
|
||||
// Print some information of the tensor
|
||||
std::cout << "Datatype: " << tensor->data_type() << std::endl;
|
||||
|
@ -268,7 +269,7 @@ TEST_F(TestTensor, InitByFloatArrayDataCTest) {
|
|||
|
||||
TEST_F(TestTensor, InitByFloatArrayDataTest) {
|
||||
// Init tensor data by py::array_t<float>
|
||||
TensorPtr tensor = std::make_shared<Tensor>(input_);
|
||||
TensorPtr tensor = std::make_shared<Tensor>(BuildInputTensor());
|
||||
|
||||
// Print some information of the tensor
|
||||
std::cout << "Datatype: " << tensor->data_type() << std::endl;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ir/dtype/number.h"
|
||||
#include "parallel/device_manager.h"
|
||||
#include "parallel/auto_parallel/edge_costmodel.h"
|
||||
#include "parallel/ops_info/matmul_info.h"
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "parallel/tensor_layout/util_layout_gen_test.h"
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
@ -23,6 +24,8 @@
|
|||
#include "parallel/tensor_layout/shape_util.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
using std::pow;
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
std::vector<std::vector<int32_t>> combine(const std::vector<int32_t>& in, int32_t target) {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
@ -765,7 +766,8 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {
|
|||
py::object obj;
|
||||
auto parameter_node = kernel_graph->add_parameter();
|
||||
MS_EXCEPTION_IF_NULL(parameter_node);
|
||||
parameter_node->set_default_param(obj);
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(obj);
|
||||
parameter_node->set_default_param(param_value_new);
|
||||
EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node));
|
||||
EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error);
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
@ -82,7 +83,8 @@ TEST_F(KernelGraphTest, NewParameter) {
|
|||
auto weight_parameter_node = anf_graph->add_parameter();
|
||||
MS_EXCEPTION_IF_NULL(weight_parameter_node);
|
||||
py::object obj;
|
||||
weight_parameter_node->set_default_param(obj);
|
||||
auto param_value_new = std::make_shared<ParamValuePy>(obj);
|
||||
weight_parameter_node->set_default_param(param_value_new);
|
||||
weight_parameter_node->set_abstract(x_abstract);
|
||||
auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node);
|
||||
EXPECT_NE(new_weight_parameter_node, nullptr);
|
||||
|
|
Loading…
Reference in New Issue