!1189 Decoupling py default param from Parameter

Merge pull request !1189 from leopz/master
mindspore-ci-bot 2020-05-20 10:02:01 +08:00 committed by Gitee
commit 04ac611fe8
48 changed files with 943 additions and 698 deletions

View File

@@ -26,6 +26,7 @@
#include "utils/graph_utils.h"
#include "utils/symbolic.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/resolve.h"
#include "operator/composite/composite.h"
@@ -469,7 +470,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNode
MS_LOG(EXCEPTION) << "Param could not cast to parameter";
}
if (param_ptr->has_default()) {
ofs << " = @" << DumpObject(param_ptr->default_param(), "D");
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
ofs << " = @" << DumpObject(param_value->value(), "D");
}
// output comment
@@ -1650,7 +1652,8 @@ class IrParser {
// load parameter default value from serialized file
py::object default_obj = LoadObject(lexer_.GetTokenText());
param->set_default_param(default_obj);
auto param_value_new = std::make_shared<ParamValuePy>(default_obj);
param->set_default_param(param_value_new);
tok = lexer_.GetNextToken();
}
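
The pattern above recurs throughout this commit: callers no longer hand a raw py::object to Parameter::set_default_param but wrap it in a ParamValuePy first. A minimal sketch of the new convention, assuming the headers introduced in this diff and a pybind11 build; the helper name is illustrative, not part of the commit:

#include <memory>
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace mindspore {
// Hypothetical helper: wrap the Python default before attaching it, since
// Parameter now stores only the backend-neutral ParamValuePtr.
void SetDefaultFromPython(const ParameterPtr &param, const py::object &obj) {
  auto wrapped = std::make_shared<ParamValuePy>(obj);
  param->set_default_param(wrapped);
}
}  // namespace mindspore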

View File

@@ -21,12 +21,17 @@
#include <vector>
#include <string>
#include "pybind11/pybind11.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "ir/primitive.h"
#include "utils/graph_utils.h"
#include "utils/utils.h"
#include "operator/composite/composite.h"
#include "ir/meta_tensor.h"
namespace py = pybind11;
namespace mindspore {
// namespace to support debug utils
@@ -312,17 +317,21 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>";
buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
if (py::hasattr(py_p, "default_input")) {
py_p = py_p.attr("default_input");
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << std::string(py::str(shape)) << "]";
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << std::string(py::str(shape)) << "]";
auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
auto py_p = param_value->value();
if (py::hasattr(py_p, "default_input")) {
py_p = py_p.attr("default_input");
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
auto m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << py::str(shape) << "]";
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
auto m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << py::str(shape) << "]";
}
}
}
buffer_ << "</td></tr>";
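
Reading a default back out follows the inverse convention: downcast the stored ParamValuePtr to the Python-side subclass, then inspect the py::object it carries. A sketch under the same assumptions as the helper above:

namespace mindspore {
// Hypothetical helper mirroring the lookup in FuncGraphParameters.
py::object GetDefaultAsPython(const ParameterPtr &param) {
  if (!param->has_default()) {
    return py::none();
  }
  auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
  return param_value->value();
}
}  // namespace mindspore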

View File

@@ -18,9 +18,9 @@
#include <utility>
#include <fstream>
#include <sstream>
#include <climits>
#include "ir/anf.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/python_adapter.h"
#include "utils/convert_utils.h"
namespace mindspore {
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {

View File

@@ -24,13 +24,10 @@
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "ir/base.h"
#include "debug/trace_info.h"
namespace mindspore {
namespace py = pybind11;
// namespace to support intermediate representation definition
enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };

View File

@@ -21,7 +21,6 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "utils/any.h"
#include "ir/anf.h"
namespace mindspore {

View File

@@ -19,8 +19,6 @@
#include <fstream>
#include <sstream>
#include "ir/anf.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {

View File

@@ -24,12 +24,9 @@
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "ir/base.h"
namespace mindspore {
namespace py = pybind11;
class TraceInfo;
using TraceInfoPtr = std::shared_ptr<TraceInfo>;
class Location;

View File

@@ -23,21 +23,11 @@
#include <vector>
#include <unordered_map>
#include "ir/visitor.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "operator/ops.h"
#include "parallel/ops_info/ops_utils.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore {
// namespace to support intermediate representation definition
// Methods of AnfNode
TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
@@ -85,66 +75,6 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str();
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
auto old_ptr = operator_info_;
operator_info_ = operator_info;
return old_ptr;
}
operator_info_ = operator_info;
return nullptr;
}
std::string CNode::fullname_with_scope() {
// if full name is set, return its name immediately
if (!fullname_with_scope_.empty()) {
return fullname_with_scope_;
}
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
IsApply(prim::kPrimHistogramSummary)) {
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
if (tag == "") {
MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
}
std::string name;
if (IsApply(prim::kPrimScalarSummary)) {
name = tag + "[:Scalar]";
} else if (IsApply(prim::kPrimImageSummary)) {
name = tag + "[:Image]";
} else if (IsApply(prim::kPrimHistogramSummary)) {
name = tag + "[:Histogram]";
} else {
name = tag + "[:Tensor]";
}
fullname_with_scope_ = name;
} else {
// cnode input 0 should be primitive ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto input_value = value_ptr->value();
if (input_value == nullptr) {
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim);
fullname_with_scope_ =
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
}
return fullname_with_scope_;
}
std::string ValueNode::ToString() const {
MS_EXCEPTION_IF_NULL(value_);
if (value_->isa<FuncGraph>()) {
@@ -173,10 +103,6 @@ std::string ValueNode::fullname_with_scope() {
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();

View File

@@ -52,6 +52,7 @@ class AbstractBase;
} // namespace abstract
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
class ValueNode;
using ValueNodePtr = std::shared_ptr<ValueNode>;
@@ -78,6 +79,13 @@ using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
class AnfVisitor;
class ParamValue {
public:
ParamValue() = default;
virtual ~ParamValue() = default;
};
using ParamValuePtr = std::shared_ptr<ParamValue>;
// AnfNode is the basic class of the IR definition derived from Base.
// Only two types of nodes are derived: CNode and ANode.
// Methods:
@@ -239,11 +247,11 @@ class ANode : public AnfNode {
// Parameter represents the parameter inputs of a function. They have no value.
// Attributes:
// default_param_: used to hold the inputting tensor of the model.
// default_param_value_: used to hold the inputting tensor of the model.
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(py::none()), tensor_layout_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
@@ -254,12 +262,11 @@ class Parameter : public ANode {
std::string fullname_with_scope() override { return name(); };
bool has_default() const { return has_default_; }
py::object default_param() { return default_param_; }
void set_default_param(const py::object &obj) {
default_param_ = obj;
void set_default_param(ParamValuePtr param) {
default_param_ = param;
has_default_ = true;
}
ParamValuePtr default_param() const { return default_param_; }
std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; }
void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) {
@@ -280,7 +287,7 @@
private:
std::string name_;
bool has_default_;
py::object default_param_;
ParamValuePtr default_param_;
std::shared_ptr<parallel::TensorLayout> tensor_layout_;
};
using ParameterPtr = std::shared_ptr<Parameter>;
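
The ParamValue base class above is deliberately empty: it only gives Parameter a pybind11-free handle, and each frontend or backend supplies its own subclass (ParamValuePy and ParamValueMinnie later in this commit). A hypothetical third subclass, to show the shape of the extension point; the name and fields are invented for illustration:

#include <cstddef>
#include "ir/anf.h"
namespace mindspore {
// Sketch only; "ParamValueLite" is not part of this commit.
class ParamValueLite : public ParamValue {
 public:
  ParamValueLite(void *data, size_t size) : data_(data), size_(size) {}
  ~ParamValueLite() override = default;
  void *data() const { return data_; }
  size_t size() const { return size_; }

 private:
  void *data_;
  size_t size_;
};
}  // namespace mindspore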

View File

@@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/anf.h"
#include <algorithm>
#include <sstream>
#include <vector>
#include <unordered_map>
#include "ir/visitor.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "operator/ops.h"
#include "parallel/ops_info/ops_utils.h"
namespace mindspore {
// namespace to support intermediate representation definition
// Methods of AnfNode
TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
auto old_ptr = operator_info_;
operator_info_ = operator_info;
return old_ptr;
}
operator_info_ = operator_info;
return nullptr;
}
std::string CNode::fullname_with_scope() {
// if full name is set, return its name immediately
if (!fullname_with_scope_.empty()) {
return fullname_with_scope_;
}
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
IsApply(prim::kPrimHistogramSummary)) {
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
if (tag == "") {
MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
}
std::string name;
if (IsApply(prim::kPrimScalarSummary)) {
name = tag + "[:Scalar]";
} else if (IsApply(prim::kPrimImageSummary)) {
name = tag + "[:Image]";
} else if (IsApply(prim::kPrimHistogramSummary)) {
name = tag + "[:Histogram]";
} else {
name = tag + "[:Tensor]";
}
fullname_with_scope_ = name;
} else {
// cnode input 0 should be primitive ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto input_value = value_ptr->value();
if (input_value == nullptr) {
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim);
fullname_with_scope_ =
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
}
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
} // namespace mindspore

View File

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
TypePtr Keyword::DeepCopy() const {
@@ -206,8 +203,6 @@ std::string Function::ToString() const {
return buffer.str();
}
TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
TypePtr JTagged::DeepCopy() const {
MS_EXCEPTION_IF_NULL(subtype_);
if (IsGeneric()) {
@@ -247,460 +242,4 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
os << problem->ToString();
return os;
}
std::size_t TypeHasher::operator()(TypePtr const &type) const {
MS_EXCEPTION_IF_NULL(type);
std::size_t hash = std::hash<size_t>()(type->type_id());
return hash;
}
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
std::size_t hash_sum = 0;
for (auto &type : type_list) {
auto type_id = static_cast<std::size_t>(type->type_id());
hash_sum = hash_combine(hash_sum, type_id);
}
return hash_sum;
}
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->type_id() == t2->type_id();
}
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
if (lhs.size() != rhs.size()) {
return false;
}
std::size_t size = lhs.size();
for (std::size_t i = 0; i < size; ++i) {
MS_EXCEPTION_IF_NULL(lhs[i]);
MS_EXCEPTION_IF_NULL(rhs[i]);
if (*lhs[i] != *rhs[i]) {
return false;
}
}
return true;
}
TypePtr TypeIdToType(TypeId id) {
switch (id) {
case kNumberTypeFloat16:
return kFloat16;
case kNumberTypeFloat:
case kNumberTypeFloat32:
return kFloat32;
case kNumberTypeFloat64:
return kFloat64;
case kNumberTypeInt8:
return kInt8;
case kNumberTypeInt16:
return kInt16;
case kNumberTypeInt32:
return kInt32;
case kNumberTypeInt64:
return kInt64;
case kNumberTypeUInt8:
return kUInt8;
case kNumberTypeUInt16:
return kUInt16;
case kNumberTypeUInt32:
return kUInt32;
case kNumberTypeUInt64:
return kUInt64;
case kNumberTypeBool:
return kBool;
case kMetaTypeExternal:
return kTypeExternal;
case kMetaTypeAnything:
return kAnyType;
case kMetaTypeNone:
return kTypeNone;
case kObjectTypeEnvType:
return kTypeEnv;
case kObjectTypeRefKey:
return kRefKeyType;
case kObjectTypeRef:
return kRefType;
case kTypeUnknown:
return kTypeNone;
default:
MS_LOG(EXCEPTION) << "Not support the type: " << id;
}
}
namespace {
template <typename T>
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
TypePtr type = nullptr;
if (type_name == num_type_name) {
type = std::make_shared<T>();
} else {
try {
if (num_type_name.size() >= type_name.size()) {
MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
<< ")";
}
auto bits = std::stoi(type_name.substr(num_type_name.size()));
type = std::make_shared<T>(bits);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what();
}
}
return type;
}
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
std::vector<TypePtr> types;
if (type_names.length() == 0) {
return types;
}
std::string::size_type start = 0;
std::string::size_type end = type_names.find_first_of(',');
while (end != std::string::npos) {
types.push_back(StringToType(type_names.substr(start, end)));
// Skip ',' to find the next element.
start = end + 1;
end = type_names.find_first_of(',', start);
}
if (start >= type_names.size()) {
MS_LOG(EXCEPTION) << "Type name is empty string.";
}
types.push_back(StringToType(type_names.substr(start)));
return types;
}
TypePtr TensorStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tensor") {
type = std::make_shared<TensorType>();
} else {
try {
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
auto element_str = type_name.substr(start, end - start);
auto element_type = StringToType(element_str);
if (element_type == nullptr) {
return nullptr;
}
type = std::make_shared<TensorType>(element_type);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "List") {
type = std::make_shared<List>();
} else {
try {
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<List>(element_types);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr TupleStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tuple") {
type = std::make_shared<Tuple>();
} else {
try {
size_t start = type_name.find_first_of('[') + 1;
size_t end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<Tuple>(element_types);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr FunctionStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Function") {
type = std::make_shared<Function>();
} else {
try {
// format: [(para1, para2, para3, ...) retval]
size_t start = type_name.find_first_of('[') + 1;
size_t end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
std::string str_all = type_name.substr(start, end - start);
size_t start_a = str_all.find_first_of('(') + 1;
size_t end_a = str_all.find_last_of(')');
if (start_a >= str_all.size()) {
return nullptr;
}
std::string str_args = str_all.substr(start_a, end_a - start_a);
// bypass " " between ")" and retval
start = end_a + 2;
if (start >= str_all.size()) {
return nullptr;
}
std::string str_retval = str_all.substr(start);
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
TypePtr retval = StringToType(str_retval);
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
if (retval == nullptr || wrong) {
return nullptr;
}
type = std::make_shared<Function>(args_type, retval);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
} // namespace
TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>();
} else if (type_name.compare("Ellipsis") == 0) {
type = std::make_shared<Ellipsis>();
} else if (type_name.compare("TypeType") == 0) {
type = std::make_shared<TypeType>();
} else if (type_name.compare("SymbolicKeyType") == 0) {
type = std::make_shared<SymbolicKeyType>();
} else if (type_name.compare("RefKeyType") == 0) {
type = std::make_shared<RefKeyType>();
} else if (type_name.compare("EnvType") == 0) {
type = std::make_shared<EnvType>();
} else if (type_name.compare("Number") == 0) {
type = std::make_shared<Number>();
} else if (type_name.compare("Bool") == 0) {
type = std::make_shared<Bool>();
} else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
type = StringToNumberType<Int>(type_name, "Int");
} else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
type = StringToNumberType<UInt>(type_name, "UInt");
} else if (type_name.compare(0, strlen("Float"), "Float") == 0) {
type = StringToNumberType<Float>(type_name, "Float");
} else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
type = TensorStrToType(type_name);
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
type = ListStrToType(type_name);
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
type = TupleStrToType(type_name);
} else if (type_name.compare("Slice") == 0) {
type = std::make_shared<Slice>();
} else if (type_name.compare("Dictionary") == 0) {
type = std::make_shared<Dictionary>();
} else if (type_name.compare("String") == 0) {
type = std::make_shared<String>();
} else if (type_name.compare("Problem") == 0) {
type = std::make_shared<Problem>();
} else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
type = FunctionStrToType(type_name);
} else {
// - unsupported to convert
// Class
// SymbolicType
// JTagged
// Anything
// External
// Problem
MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
}
return type;
}
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr.";
return false;
}
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
return false;
} else if (!(base_type->IsGeneric())) {
return *(base_type) == *(x);
} else if (base_type->type_id() == x->type_id()) {
return true;
} else if (base_type->type_id() == x->generic_type_id()) {
return true;
} else if (base_type->type_id() == x->object_type()) {
return true;
} else if (base_type->type_id() == x->meta_type()) {
return true;
} else {
return false;
}
}
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
MS_EXCEPTION_IF_NULL(t1);
if (t1->type_id() == kTypeUnknown) {
return false;
} else if (t2 != nullptr) {
return IsIdentidityOrSubclass(t1, t2);
} else {
return true;
}
}
REGISTER_PYBIND_DEFINE(
typing, ([](py::module *const m) {
auto m_sub = m->def_submodule("typing", "submodule for dtype");
py::enum_<TypeId>(m_sub, "TypeId");
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
(void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def(
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__",
[](const TypePtr &t1, const TypePtr &t2) {
if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2;
}
return false;
})
.def("__hash__", &Type::hash)
.def("__str__", &Type::ToString)
.def("__repr__", &Type::ReprString)
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
if (t == nullptr) {
return static_cast<TypePtr>(nullptr);
}
return t->DeepCopy();
});
(void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init())
.def(py::pickle(
[](const Bool &) { // __getstate__
return py::make_tuple();
},
[](const py::tuple &) { // __setstate__
return std::make_shared<Bool>();
}));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Int &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
Int data(t[0].cast<py::int_>());
return data;
}));
(void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const UInt &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
UInt data(t[0].cast<py::int_>());
return data;
}));
(void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Float &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
Float data(t[0].cast<py::int_>());
return data;
}));
(void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
.def(py::init())
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
(void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple")
.def(py::init())
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
(void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType")
.def(py::init())
.def(py::init<TypePtr>(), py::arg("element"))
.def("element_type", &TensorType::element)
.def(py::pickle(
[](const TensorType &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
return data;
}));
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
.def(py::init())
.def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
(void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init());
(void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init());
(void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init());
(void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init());
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
}));
const TypePtr kTypeExternal = std::make_shared<External>();
const TypePtr kTypeEnv = std::make_shared<EnvType>();
const TypePtr kTypeType = std::make_shared<TypeType>();
const TypePtr kTensorType = std::make_shared<TensorType>();
const TypePtr kString = std::make_shared<String>();
} // namespace mindspore

View File

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {

View File

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
bool Number::operator==(const Type &other) const {

View File

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
TypePtr RefType::DeepCopy() const {

View File

@@ -21,9 +21,8 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "ir/dtype/number.h"
#include "utils/convert_utils.h"
namespace mindspore {
TypeId IntBitsToTypeId(const int nbits) {
@@ -227,11 +226,6 @@ bool Type::operator==(const Value &other) const {
}
}
abstract::AbstractBasePtr Type::ToAbstract() {
abstract::AbstractBasePtr ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
return ptr;
}
std::ostream &operator<<(std::ostream &os, const Type &type) {
os << type.ToString();
return os;

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/dtype/type.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
abstract::AbstractBasePtr Type::ToAbstract() {
auto ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
return ptr;
}
} // namespace mindspore

View File

@@ -0,0 +1,484 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/dtype.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
std::size_t TypeHasher::operator()(TypePtr const &type) const {
MS_EXCEPTION_IF_NULL(type);
std::size_t hash = std::hash<size_t>()(type->type_id());
return hash;
}
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
std::size_t hash_sum = 0;
for (auto &type : type_list) {
auto type_id = static_cast<std::size_t>(type->type_id());
hash_sum = hash_combine(hash_sum, type_id);
}
return hash_sum;
}
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->type_id() == t2->type_id();
}
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
if (lhs.size() != rhs.size()) {
return false;
}
std::size_t size = lhs.size();
for (std::size_t i = 0; i < size; ++i) {
MS_EXCEPTION_IF_NULL(lhs[i]);
MS_EXCEPTION_IF_NULL(rhs[i]);
if (*lhs[i] != *rhs[i]) {
return false;
}
}
return true;
}
TypePtr TypeIdToType(TypeId id) {
switch (id) {
case kNumberTypeFloat16:
return kFloat16;
case kNumberTypeFloat:
case kNumberTypeFloat32:
return kFloat32;
case kNumberTypeFloat64:
return kFloat64;
case kNumberTypeInt8:
return kInt8;
case kNumberTypeInt16:
return kInt16;
case kNumberTypeInt32:
return kInt32;
case kNumberTypeInt64:
return kInt64;
case kNumberTypeUInt8:
return kUInt8;
case kNumberTypeUInt16:
return kUInt16;
case kNumberTypeUInt32:
return kUInt32;
case kNumberTypeUInt64:
return kUInt64;
case kNumberTypeBool:
return kBool;
case kMetaTypeExternal:
return kTypeExternal;
case kMetaTypeAnything:
return kAnyType;
case kMetaTypeNone:
return kTypeNone;
case kObjectTypeEnvType:
return kTypeEnv;
case kObjectTypeRefKey:
return kRefKeyType;
case kObjectTypeRef:
return kRefType;
case kTypeUnknown:
return kTypeNone;
default:
MS_LOG(EXCEPTION) << "Not support the type: " << id;
}
}
namespace {
template <typename T>
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
TypePtr type = nullptr;
if (type_name == num_type_name) {
type = std::make_shared<T>();
} else {
try {
if (num_type_name.size() >= type_name.size()) {
MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
<< ")";
}
auto bits = std::stoi(type_name.substr(num_type_name.size()));
type = std::make_shared<T>(bits);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what();
}
}
return type;
}
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
std::vector<TypePtr> types;
if (type_names.length() == 0) {
return types;
}
std::string::size_type start = 0;
std::string::size_type end = type_names.find_first_of(',');
while (end != std::string::npos) {
types.push_back(StringToType(type_names.substr(start, end)));
// Skip ',' to find the next element.
start = end + 1;
end = type_names.find_first_of(',', start);
}
if (start >= type_names.size()) {
MS_LOG(EXCEPTION) << "Type name is empty string.";
}
types.push_back(StringToType(type_names.substr(start)));
return types;
}
TypePtr TensorStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tensor") {
type = std::make_shared<TensorType>();
} else {
try {
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
auto element_str = type_name.substr(start, end - start);
auto element_type = StringToType(element_str);
if (element_type == nullptr) {
return nullptr;
}
type = std::make_shared<TensorType>(element_type);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "List") {
type = std::make_shared<List>();
} else {
try {
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<List>(element_types);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr TupleStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tuple") {
type = std::make_shared<Tuple>();
} else {
try {
size_t start = type_name.find_first_of('[') + 1;
size_t end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<Tuple>(element_types);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr FunctionStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Function") {
type = std::make_shared<Function>();
} else {
try {
// format: [(para1, para2, para3, ...) retval]
size_t start = type_name.find_first_of('[') + 1;
size_t end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
std::string str_all = type_name.substr(start, end - start);
size_t start_a = str_all.find_first_of('(') + 1;
size_t end_a = str_all.find_last_of(')');
if (start_a >= str_all.size()) {
return nullptr;
}
std::string str_args = str_all.substr(start_a, end_a - start_a);
// bypass " " between ")" and retval
start = end_a + 2;
if (start >= str_all.size()) {
return nullptr;
}
std::string str_retval = str_all.substr(start);
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
TypePtr retval = StringToType(str_retval);
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
if (retval == nullptr || wrong) {
return nullptr;
}
type = std::make_shared<Function>(args_type, retval);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
}
}
return type;
}
} // namespace
TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>();
} else if (type_name.compare("Ellipsis") == 0) {
type = std::make_shared<Ellipsis>();
} else if (type_name.compare("TypeType") == 0) {
type = std::make_shared<TypeType>();
} else if (type_name.compare("SymbolicKeyType") == 0) {
type = std::make_shared<SymbolicKeyType>();
} else if (type_name.compare("RefKeyType") == 0) {
type = std::make_shared<RefKeyType>();
} else if (type_name.compare("EnvType") == 0) {
type = std::make_shared<EnvType>();
} else if (type_name.compare("Number") == 0) {
type = std::make_shared<Number>();
} else if (type_name.compare("Bool") == 0) {
type = std::make_shared<Bool>();
} else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
type = StringToNumberType<Int>(type_name, "Int");
} else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
type = StringToNumberType<UInt>(type_name, "UInt");
} else if (type_name.compare(0, strlen("Float"), "Float") == 0) {
type = StringToNumberType<Float>(type_name, "Float");
} else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
type = TensorStrToType(type_name);
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
type = ListStrToType(type_name);
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
type = TupleStrToType(type_name);
} else if (type_name.compare("Slice") == 0) {
type = std::make_shared<Slice>();
} else if (type_name.compare("Dictionary") == 0) {
type = std::make_shared<Dictionary>();
} else if (type_name.compare("String") == 0) {
type = std::make_shared<String>();
} else if (type_name.compare("Problem") == 0) {
type = std::make_shared<Problem>();
} else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
type = FunctionStrToType(type_name);
} else {
// - unsupported to convert
// Class
// SymbolicType
// JTagged
// Anything
// External
// Problem
MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
}
return type;
}
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr.";
return false;
}
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
return false;
} else if (!(base_type->IsGeneric())) {
return *(base_type) == *(x);
} else if (base_type->type_id() == x->type_id()) {
return true;
} else if (base_type->type_id() == x->generic_type_id()) {
return true;
} else if (base_type->type_id() == x->object_type()) {
return true;
} else if (base_type->type_id() == x->meta_type()) {
return true;
} else {
return false;
}
}
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
MS_EXCEPTION_IF_NULL(t1);
if (t1->type_id() == kTypeUnknown) {
return false;
} else if (t2 != nullptr) {
return IsIdentidityOrSubclass(t1, t2);
} else {
return true;
}
}
REGISTER_PYBIND_DEFINE(
typing, ([](py::module *const m) {
auto m_sub = m->def_submodule("typing", "submodule for dtype");
py::enum_<TypeId>(m_sub, "TypeId");
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
(void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def(
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__",
[](const TypePtr &t1, const TypePtr &t2) {
if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2;
}
return false;
})
.def("__hash__", &Type::hash)
.def("__str__", &Type::ToString)
.def("__repr__", &Type::ReprString)
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
if (t == nullptr) {
return static_cast<TypePtr>(nullptr);
}
return t->DeepCopy();
});
(void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init())
.def(py::pickle(
[](const Bool &) { // __getstate__
return py::make_tuple();
},
[](const py::tuple &) { // __setstate__
return std::make_shared<Bool>();
}));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Int &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
Int data(t[0].cast<py::int_>());
return data;
}));
(void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const UInt &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
UInt data(t[0].cast<py::int_>());
return data;
}));
(void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Float &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
Float data(t[0].cast<py::int_>());
return data;
}));
(void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
.def(py::init())
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
(void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple")
.def(py::init())
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
(void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType")
.def(py::init())
.def(py::init<TypePtr>(), py::arg("element"))
.def("element_type", &TensorType::element)
.def(py::pickle(
[](const TensorType &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
return data;
}));
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
.def(py::init())
.def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
(void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init());
(void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init());
(void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init());
(void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init());
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
}));
const TypePtr kTypeExternal = std::make_shared<External>();
const TypePtr kTypeEnv = std::make_shared<EnvType>();
const TypePtr kTypeType = std::make_shared<TypeType>();
const TypePtr kTensorType = std::make_shared<TensorType>();
const TypePtr kString = std::make_shared<String>();
} // namespace mindspore

View File

@@ -19,6 +19,7 @@
#include <algorithm>
#include "ir/manager.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"
@@ -69,7 +70,9 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
new_param->set_abstract(old_param->abstract());
new_param->set_name(old_param->name());
if (old_param->has_default()) {
new_param->set_default_param(old_param->default_param());
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
new_param->set_default_param(param_value_new);
}
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_param->set_scope(scope);
@@ -248,7 +251,9 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
if (node->isa<Parameter>()) {
ParameterPtr old_param = dyn_cast<Parameter>(node);
if (old_param->has_default()) {
param->set_default_param(old_param->default_param());
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
param->set_default_param(param_value_new);
}
param->set_name(old_param->name());
}
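
Note the clone semantics implied above: a fresh ParamValuePy wrapper is built, but it wraps the same underlying py::object, so the Python tensor itself is shared between the original and the cloned parameter, matching the old behavior of copying the py::object handle. A condensed sketch, same assumptions as before:

namespace mindspore {
// Hypothetical helper condensing the two CloneParameter branches.
void CloneDefault(const ParameterPtr &from, const ParameterPtr &to) {
  if (!from->has_default()) {
    return;
  }
  auto value = std::dynamic_pointer_cast<ParamValuePy>(from->default_param());
  to->set_default_param(std::make_shared<ParamValuePy>(value->value()));
}
}  // namespace mindspore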

View File

@@ -28,6 +28,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
namespace mindspore {
class Cloner;

View File

@@ -31,6 +31,7 @@
#include "ir/dtype.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/signature.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace py = pybind11;

View File

@@ -21,7 +21,6 @@
#include <memory>
#include <functional>
#include "ir/base.h"
#include "ir/anf.h"
namespace mindspore {

View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
#include <memory>
#include "ir/anf.h"
namespace mindspore {
class ParamValueMinnie : public ParamValue {
public:
ParamValueMinnie() : tensor_addr_(nullptr), tensor_size_(0) {}
virtual ~ParamValueMinnie() = default;
size_t tensor_size() const { return tensor_size_; }
void set_tensor_size(size_t size) { tensor_size_ = size; }
void *tensor_addr() const { return tensor_addr_; }
void set_tensor_addr(void *addr) { tensor_addr_ = addr; }
private:
void *tensor_addr_;
size_t tensor_size_;
};
using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
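
A short usage sketch for the Minnie-side value, assuming this header plus "ir/anf.h"; the helper and the raw buffer are illustrative:

#include <memory>
#include "ir/anf.h"
#include "ir/param_value_minnie.h"
namespace mindspore {
// Hypothetical helper: attach a raw tensor buffer as a parameter default.
void AttachRawBuffer(const ParameterPtr &param, void *addr, size_t size) {
  auto value = std::make_shared<ParamValueMinnie>();
  value->set_tensor_addr(addr);
  value->set_tensor_size(size);
  param->set_default_param(value);  // stored through the common ParamValuePtr
}
}  // namespace mindspore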

View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#include <memory>
#include "ir/anf.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace py = pybind11;
class ParamValuePy : public ParamValue {
public:
ParamValuePy() : value_(py::none()) {}
explicit ParamValuePy(py::object value) : value_(value) {}
virtual ~ParamValuePy() = default;
py::object value() { return value_; }
void set_value(const py::object &obj) { value_ = obj; }
private:
py::object value_;
};
using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
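
One caveat with this split: std::dynamic_pointer_cast returns nullptr when the stored ParamValue belongs to a different subclass (for example ParamValueMinnie), so code that can see mixed backends should check the cast before dereferencing. A defensive sketch, assuming this header plus "utils/log_adapter.h" for MS_LOG; the helper name is illustrative:

namespace mindspore {
// Hypothetical guard around the downcast used throughout this commit.
py::object RequirePythonDefault(const ParameterPtr &param) {
  auto py_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
  if (py_value == nullptr) {
    MS_LOG(EXCEPTION) << "Default parameter is not a Python-backed value";
  }
  return py_value->value();
}
}  // namespace mindspore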

View File

@@ -17,20 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_SCALAR_H_
#define MINDSPORE_CCSRC_IR_SCALAR_H_
namespace mindspore {
/* namespace to support inference engine */
#include <type_traits>
#include <algorithm>
#include <cmath>
#include <vector>
#include <string>
#include <memory>
#include <sstream>
#include <utility>
#include <cfloat>
#include "ir/base.h"
#include "ir/dtype.h"
#include "ir/dtype/number.h"
using std::fabs;
namespace mindspore {
class Scalar : public Value {
public:
Scalar() = default;

View File

@@ -19,9 +19,7 @@
#include <memory>
#include <cmath>
#include <cfloat>
#include "pybind_api/api_register.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/convert_utils.h"
namespace mindspore {
const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
@@ -208,41 +206,6 @@ bool AnyValue::operator==(const Value &other) const {
}
}
const ValuePtr kAnyValue = std::make_shared<AnyValue>();
using ContextPtr = abstract::AnalysisContextPtr;
abstract::AbstractBasePtr Scalar::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
}
abstract::AbstractBasePtr StringImm::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
}
abstract::AbstractBasePtr RefKey::ToAbstract() {
auto refkey = std::make_shared<abstract::AbstractRefKey>();
refkey->set_value(shared_from_base<Value>());
return refkey;
}
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
return std::make_shared<abstract::AbstractTuple>(a_list);
}
abstract::AbstractBasePtr ValueList::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
return std::make_shared<abstract::AbstractList>(a_list);
}
std::size_t ValueSlice::hash() const {
MS_EXCEPTION_IF_NULL(start_);
@@ -280,16 +243,6 @@ std::string ValueSlice::ToString() const {
return buffer.str();
}
abstract::AbstractBasePtr ValueSlice::ToAbstract() {
MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_);
abstract::AbstractBasePtr start = start_->ToAbstract();
abstract::AbstractBasePtr end = stop_->ToAbstract();
abstract::AbstractBasePtr step = step_->ToAbstract();
return std::make_shared<abstract::AbstractSlice>(start, end, step);
}
std::size_t KeywordArg::hash() const {
MS_EXCEPTION_IF_NULL(value_);
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
@@ -316,12 +269,6 @@ std::string KeywordArg::ToString() const {
return buffer.str();
}
abstract::AbstractBasePtr KeywordArg::ToAbstract() {
MS_EXCEPTION_IF_NULL(value_);
abstract::AbstractBasePtr argument = value_->ToAbstract();
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
}
const ValuePtr ValueDictionary::operator[](const std::string &key) const {
auto it = std::find_if(key_values_.begin(), key_values_.end(),
[key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
@@ -354,17 +301,4 @@ bool ValueDictionary::operator==(const ValueDictionary &other) const {
}
return true;
}
abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
(void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
return std::make_shared<abstract::AbstractDictionary>(kv);
}
REGISTER_PYBIND_DEFINE(
RefKey, ([](const py::module *m) {
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
}));
} // namespace mindspore

View File

@@ -29,6 +29,7 @@
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/scalar.h"
#include "ir/dtype/ref.h"
#include "utils/hashing.h"
#include "common/utils.h"

View File

@@ -0,0 +1,91 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/value.h"
#include <algorithm>
#include <memory>
#include <cmath>
#include <cfloat>
#include "pybind_api/api_register.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
using ContextPtr = abstract::AnalysisContextPtr;
abstract::AbstractBasePtr Scalar::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
}
abstract::AbstractBasePtr StringImm::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
}
abstract::AbstractBasePtr RefKey::ToAbstract() {
auto refkey = std::make_shared<abstract::AbstractRefKey>();
refkey->set_value(shared_from_base<Value>());
return refkey;
}
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
return std::make_shared<abstract::AbstractTuple>(a_list);
}
abstract::AbstractBasePtr ValueList::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
return std::make_shared<abstract::AbstractList>(a_list);
}
abstract::AbstractBasePtr ValueSlice::ToAbstract() {
MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_);
abstract::AbstractBasePtr start = start_->ToAbstract();
abstract::AbstractBasePtr end = stop_->ToAbstract();
abstract::AbstractBasePtr step = step_->ToAbstract();
return std::make_shared<abstract::AbstractSlice>(start, end, step);
}
abstract::AbstractBasePtr KeywordArg::ToAbstract() {
MS_EXCEPTION_IF_NULL(value_);
abstract::AbstractBasePtr argument = value_->ToAbstract();
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
}
abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
(void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
return std::make_shared<abstract::AbstractDictionary>(kv);
}
REGISTER_PYBIND_DEFINE(
RefKey, ([](const py::module *m) {
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
}));
} // namespace mindspore
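The ToAbstract overrides and the RefKey pybind registration collected in this new file are the same ones deleted from the old translation unit earlier in this diff, moved over verbatim. A standalone usage sketch of what ToAbstract yields for a tuple value (hypothetical; Example and its use of MakeValue are illustrations, not part of the commit):

// Lift a concrete ValueTuple into the abstract domain used by type inference.
// Hypothetical usage sketch, assuming the usual ir/value.h helpers.
void Example() {
  std::vector<ValuePtr> elems = {MakeValue(1), MakeValue(2.0f)};
  auto tuple = std::make_shared<ValueTuple>(elems);
  abstract::AbstractBasePtr abs = tuple->ToAbstract();
  // abs is an AbstractTuple whose AbstractScalar elements still carry the
  // concrete values 1 and 2.0f together with their types.
}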

View File

@ -26,6 +26,7 @@
#include "debug/anf_ir_utils.h"
#include "proto/onnx.pb.h"
#include "operator/ops.h"
#include "ir/param_value_py.h"
namespace mindspore {
enum OpMergeMode {
@ -424,7 +425,8 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP
initializer_proto->set_name(param_ptr->ToString());
SetTensorProtoInfo(param_ptr, initializer_proto);
// set value for initializer
py::object obj = param_ptr->default_param();
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
py::object obj = param_value->value();
py::object data = obj.attr("data");
if (py::isinstance<tensor::Tensor>(data)) {
auto method = data.attr("asnumpy");
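Every call site in this commit follows the shape of the hunk above: cast the stored base pointer to ParamValuePy, then unwrap the py::object. The new ir/param_value_py.h itself does not appear in this diff; the following is a minimal sketch consistent with those call sites (the ParamValue base name and exact layout are inferred and may differ from the real header):

// Sketch of ir/param_value_py.h, reconstructed from usage; not the actual file.
#include <memory>
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace mindspore {
// Python-free base type stored on Parameter, so ir/anf.h need not see pybind11.
class ParamValue {
 public:
  virtual ~ParamValue() = default;
};
using ParamValuePtr = std::shared_ptr<ParamValue>;
// Python-side holder for the py::object that used to live directly on Parameter.
class ParamValuePy : public ParamValue {
 public:
  explicit ParamValuePy(const py::object &value) : value_(value) {}
  py::object value() { return value_; }
 private:
  py::object value_;
};
using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
}  // namespace mindspore

On this layout, set_default_param accepts the base pointer, which is why every reader must dynamic_pointer_cast back down before touching Python.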

View File

@ -18,9 +18,10 @@
#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_
#include "pybind11/stl.h"
#include "pybind11/pybind11.h"
#include "ir/anf.h"
namespace py = pybind11;
namespace mindspore {
namespace parallel {
py::dict GetParameterLayout(const FuncGraphPtr &graph);

View File

@ -19,6 +19,7 @@
#include <string>
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
@ -37,7 +38,8 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
if (!para_ptr->has_default()) {
return false;
}
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(para_ptr->default_param(), "requires_grad"));
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(para_ptr->default_param());
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
}
} // namespace parallel
} // namespace mindspore
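The three-step unwrap above (cast, value(), GetPyObjAttr) recurs in nearly every parallel-module hunk below. A hypothetical helper, not part of the commit, that captures the pattern under the ParamValuePy sketch above:

// Hypothetical convenience wrapper; not in this commit.
static py::object DefaultPyObjOf(const ParameterPtr &param) {
  auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
  MS_EXCEPTION_IF_NULL(param_value);
  return param_value->value();
}
// ParameterRequireGrad would then reduce to:
// return py::cast<bool>(parse::python_adapter::GetPyObjAttr(DefaultPyObjOf(para_ptr), "requires_grad"));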

View File

@ -28,6 +28,7 @@
#include <vector>
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "ir/meta_tensor.h"
#include "optimizer/opt.h"
#include "optimizer/optimizer.h"
@ -190,8 +191,8 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
bool require_grad =
py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
is_parameter.push_back(require_grad);
} else {
is_parameter.push_back(false);
@ -835,8 +836,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(casted_target_parameter);
if (casted_target_parameter->has_default()) {
bool require_grad = py::cast<bool>(
parse::python_adapter::GetPyObjAttr(casted_target_parameter->default_param(), "requires_grad"));
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param());
bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
is_parameter.push_back(require_grad);
} else {
is_parameter.push_back(false);

View File

@ -28,6 +28,7 @@
#include <utility>
#include "ir/meta_tensor.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "optimizer/optimizer.h"
#include "parallel/auto_parallel/graph_costmodel.h"
@ -1292,7 +1293,8 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_nod
return false;
}
py::object clone_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO);
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
bool cloned = py::cast<bool>(parse::python_adapter::GetPyObjAttr(clone_info, CLONED));
if (!cloned) {
return false;
@ -1314,7 +1316,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
}
// get the cloned index
py::object cloned_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO);
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
py::object cloned_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
int32_t cloned_index = py::cast<int32_t>(parse::python_adapter::GetPyObjAttr(cloned_info, CLONED_INDEX));
// find the be cloned parameter
@ -1329,7 +1332,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
continue;
}
py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(be_cloned_parameter->default_param(), CLONE_INFO);
auto param_value_cloned = std::dynamic_pointer_cast<ParamValuePy>(be_cloned_parameter->default_param());
py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(param_value_cloned->value(), CLONE_INFO);
if (!py::cast<bool>(parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED))) {
continue;
}
@ -2072,9 +2076,9 @@ std::string NodeParameterName(const CNodePtr &node) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) {
return py::cast<std::string>(
parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME));
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), REQUIRES_GRAD))) {
return py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), PARAM_NAME));
}
}
}
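These call sites dereference the result of dynamic_pointer_cast without a null check, which is safe only while every default_param really is a ParamValuePy. A defensive variant (illustrative; the commit does not add this guard):

auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
if (param_value == nullptr) {
  // default_param was set, but not through the Python wrapper.
  MS_LOG(EXCEPTION) << "default_param of " << cloned_parameter->ToString() << " is not a ParamValuePy";
}
py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);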

View File

@ -24,6 +24,7 @@
#include <functional>
#include "ir/func_graph_cloner.h"
#include "ir/param_value_py.h"
#include "parallel/costmodel_context.h"
#include "parallel/context.h"
#include "pipeline/pass.h"
@ -225,8 +226,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
AbstractBasePtr ptr =
abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true);
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);
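The true passed to FromValue broadens the converted default: the abstract value keeps the tensor's type and shape but drops the concrete numbers, so specialization does not bake weight values into the compiled graph. Restated as a sketch (the /*broaden=*/ annotation is editorial, not a code change in this commit):

ValuePtr converted = parse::data_converter::PyDataToValue(param_value->value());
AbstractBasePtr broadened = abstract::FromValue(converted, /*broaden=*/true);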

View File

@ -25,8 +25,11 @@
#include "operator/ops.h"
#include "debug/info.h"
#include "debug/trace.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace py = pybind11;
namespace parse {
FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
func_graph_ = std::make_shared<FuncGraph>();

View File

@ -18,11 +18,13 @@
#define PIPELINE_PARSE_PARSE_BASE_H_
#include <string>
#include <memory>
#include "pybind11/pybind11.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "pybind_api/export_flags.h"
namespace py = pybind11;
namespace mindspore {
namespace parse {
// define the node type

View File

@ -21,6 +21,7 @@
#include <vector>
#include <algorithm>
#include "ir/param_value_py.h"
#include "pipeline/parse/data_converter.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/python_adapter.h"
@ -101,8 +102,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
}
}
if (para_node == nullptr) {
ParameterPtr node = top_graph->AddWeightParameter(param_name);
node->set_default_param(obj);
auto node = top_graph->AddWeightParameter(param_name);
auto param_value_new = std::make_shared<ParamValuePy>(obj);
node->set_default_param(param_value_new);
// set_abstract for parameter
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));

View File

@ -24,6 +24,7 @@
#include <cstdlib>
#include <algorithm>
#include "ir/param_value_py.h"
#include "pipeline/pass.h"
#include "pipeline/parse/data_converter.h"
#include "optimizer/ad/dfunctor.h"
@ -619,7 +620,12 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
// maybe some default parameter
for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
MS_EXCEPTION_IF_NULL(graph_params[i]);
py::object obj = dyn_cast<Parameter>(graph_params[i])->default_param();
auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
if (!param_ptr->has_default()) {
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
py::object obj = param_value->value();
py::object p_value = py::cast<py::object>(parse::python_adapter::GetPyObjAttr(obj, "default_input"));
(*arg_list).push_back(p_value);
}

mindspore/ccsrc/session/kernel_graph.cc Executable file → Normal file
View File

@ -20,6 +20,7 @@
#include <unordered_set>
#include "common/utils.h"
#include "operator/ops.h"
#include "ir/param_value_py.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/kernel_build_info.h"
@ -232,7 +233,9 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
new_parameter->set_abstract(parameter->abstract());
new_parameter->set_name(parameter->name());
if (AnfAlgo::IsParameterWeight(parameter)) {
new_parameter->set_default_param(parameter->default_param());
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
new_parameter->set_default_param(param_value_new);
kernel_info->SetFeatureMapFlag(false);
} else {
kernel_info->SetFeatureMapFlag(true);
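Note the copy semantics in the hunk above: NewParameter builds a fresh ParamValuePy rather than sharing the old wrapper, but copying a py::object only copies a handle, so both parameters still reference the same underlying Python payload. In sketch form:

// Two wrappers, one Python object: the refcount is bumped, nothing is deep-copied.
auto src = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
new_parameter->set_default_param(std::make_shared<ParamValuePy>(src->value()));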

View File

@ -20,6 +20,7 @@
#include <unordered_set>
#include "pipeline/parse/data_converter.h"
#include "ir/manager.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "common/trans.h"
#include "utils/context/ms_context.h"
@ -44,10 +45,11 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) {
return nullptr;
}
auto parameter = node->cast<ParameterPtr>();
if (parameter == nullptr) {
if (parameter == nullptr || !parameter->has_default()) {
return nullptr;
}
auto py_param = parameter->default_param();
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
auto py_param = param_value->value();
if (!py::hasattr(py_param, "default_input")) {
return nullptr;
}
@ -319,7 +321,8 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) {
py::object obj;
param->set_default_param(obj);
auto param_value_new = std::make_shared<ParamValuePy>(obj);
param->set_default_param(param_value_new);
}
// set the kernel info of parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
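Wrapping a default-constructed, empty py::object still counts as having a default, and that presence alone marks the run-op parameter as a weight; the IsParameterWeight unit test near the end of this diff exercises exactly this. The intent, restated as a sketch:

if (tensor_mask == kParameterWeightTensorMask) {
  py::object obj;  // empty handle; only the existence of a default matters here
  param->set_default_param(std::make_shared<ParamValuePy>(obj));
}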

View File

@ -25,9 +25,11 @@
#include <sstream>
#include <utility>
#include <iterator>
#include "pybind11/pybind11.h"
#include "ir/value.h"
namespace py = pybind11;
namespace mindspore {
class BaseRef;
class VectorRef;

View File

@ -16,6 +16,7 @@
#include "utils/callbacks_ge.h"
#include "pybind11/pybind11.h"
#include "ir/param_value_py.h"
#include "transform/df_graph_manager.h"
#include "transform/util.h"
#include "pipeline/parse/data_converter.h"
@ -49,7 +50,11 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name,
return false;
}
if (param_node->name() == param_name) {
py::object parameter = param_node->default_param();
py::object parameter;
if (param_node->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
parameter = param_value->value();
}
ValuePtr value = parse::data_converter::PyDataToValue(parameter);
TensorPtr tensor = std::dynamic_pointer_cast<tensor::Tensor>(value);
if (tensor == nullptr) {

View File

@ -19,7 +19,9 @@
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/parse.h"
#include "./common.h"

View File

@ -16,6 +16,10 @@
#include <iostream>
#include "common/common_test.h"
#include "ir/dtype.h"
#include "ir/dtype/ref.h"
#include "ir/dtype/number.h"
#include "ir/dtype/container.h"
#include "ir/dtype/empty.h"
namespace mindspore {
class TestDType : public UT::Common {

View File

@ -92,20 +92,21 @@ class TestTensor : public UT::Common {
TestTensor() {}
virtual void SetUp() {
UT::InitPythonPath();
// Init tensor data by py::array_t<float>
input_ = py::array_t<float, py::array::c_style>({2, 3});
auto array = input_.mutable_unchecked();
float start = 0;
for (int i = 0; i < array.shape(0); i++) {
for (int j = 0; j < array.shape(1); j++) {
array(i, j) = start++;
}
}
};
protected:
py::array_t<float, py::array::c_style> input_;
};
py::array_t<float, py::array::c_style> BuildInputTensor() {
// Init tensor data by py::array_t<float>
py::array_t<float, py::array::c_style> input = py::array_t<float, py::array::c_style>({2, 3});
auto array = input.mutable_unchecked();
float start = 0;
for (int i = 0; i < array.shape(0); i++) {
for (int j = 0; j < array.shape(1); j++) {
array(i, j) = start++;
}
}
return input;
}
TEST_F(TestTensor, PyArrayScalarTest) {
std::vector<int> dimensions;
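Replacing the shared input_ fixture member with BuildInputTensor() gives each test an independent buffer, so a test that mutates its tensor can no longer leak state into the next. Call sites change only at construction, as the hunks below show:

// Before: tests shared one mutable fixture member.
//   TensorPtr tensor = std::make_shared<Tensor>(input_);
// After: each test builds its own fresh 2x3 array filled with 0..5.
auto tensor = std::make_shared<Tensor>(BuildInputTensor());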
@ -246,7 +247,7 @@ TEST_F(TestTensor, PyArrayTest) {
TEST_F(TestTensor, InitByFloatArrayDataCTest) {
// Init tensor data by py::array_t<float>
TensorPtr tensor = std::make_shared<Tensor>(input_);
auto tensor = std::make_shared<Tensor>(BuildInputTensor());
// Print some information of the tensor
std::cout << "Datatype: " << tensor->data_type() << std::endl;
@ -268,7 +269,7 @@ TEST_F(TestTensor, InitByFloatArrayDataCTest) {
TEST_F(TestTensor, InitByFloatArrayDataTest) {
// Init tensor data by py::array_t<float>
TensorPtr tensor = std::make_shared<Tensor>(input_);
TensorPtr tensor = std::make_shared<Tensor>(BuildInputTensor());
// Print some information of the tensor
std::cout << "Datatype: " << tensor->data_type() << std::endl;

View File

@ -15,6 +15,7 @@
*/
#include "common/common_test.h"
#include "ir/dtype/number.h"
#include "parallel/device_manager.h"
#include "parallel/auto_parallel/edge_costmodel.h"
#include "parallel/ops_info/matmul_info.h"

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "parallel/tensor_layout/util_layout_gen_test.h"
#include <cmath>
#include <map>
#include <tuple>
#include <vector>
@ -23,6 +24,8 @@
#include "parallel/tensor_layout/shape_util.h"
#include "common/common_test.h"
using std::pow;
namespace mindspore {
namespace parallel {
std::vector<std::vector<int32_t>> combine(const std::vector<int32_t>& in, int32_t target) {

View File

@ -15,6 +15,7 @@
*/
#include "common/common_test.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
@ -765,7 +766,8 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {
py::object obj;
auto parameter_node = kernel_graph->add_parameter();
MS_EXCEPTION_IF_NULL(parameter_node);
parameter_node->set_default_param(obj);
auto param_value_new = std::make_shared<ParamValuePy>(obj);
parameter_node->set_default_param(param_value_new);
EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node));
EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error);
}

View File

@ -15,6 +15,7 @@
*/
#include "common/common_test.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
@ -82,7 +83,8 @@ TEST_F(KernelGraphTest, NewParameter) {
auto weight_parameter_node = anf_graph->add_parameter();
MS_EXCEPTION_IF_NULL(weight_parameter_node);
py::object obj;
weight_parameter_node->set_default_param(obj);
auto param_value_new = std::make_shared<ParamValuePy>(obj);
weight_parameter_node->set_default_param(param_value_new);
weight_parameter_node->set_abstract(x_abstract);
auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node);
EXPECT_NE(new_weight_parameter_node, nullptr);