forked from OSSInnovation/mindspore

Add IndexedSlices

commit d6635bbbe2 (parent d454daec1b)

@@ -17,6 +17,7 @@
 """Resources for ast tree parse."""
 import ast
 import math
+from mindspore import IndexedSlices
 from mindspore.ops.composite import multitype_ops
 from mindspore.ops import functional as F, composite as C
 from . import standard_method as M

@@ -135,4 +136,7 @@ convert_object_map = {
     math.sin: NO_IMPLEMENT,
     math.cos: NO_IMPLEMENT,
     math.tan: NO_IMPLEMENT,
+
+    # user defined
+    IndexedSlices: F.make_indexed_slices,
 }

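Note: with the convert-map entry above, an `IndexedSlices(...)` call inside a cell's `construct` compiles to the `make_indexed_slices` primitive, and the method-map entries added later in this commit expose `values()`, `indices()` and `dense_shape()` on the result. A minimal usage sketch (the cell name and shapes are illustrative, not part of this commit):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, IndexedSlices

    class MakeIndexedSlicesNet(nn.Cell):
        """Builds an IndexedSlices in graph mode and reads its parts back."""
        def __init__(self, dense_shape):
            super(MakeIndexedSlicesNet, self).__init__()
            self.dense_shape = dense_shape

        def construct(self, indices, values):
            sparse = IndexedSlices(indices, values, self.dense_shape)
            # These method calls lower to the new getter primitives.
            return sparse.values(), sparse.indices(), sparse.dense_shape()

    net = MakeIndexedSlicesNet((3, 4))
    out = net(Tensor(np.array([0, 2], dtype=np.int32)),
              Tensor(np.ones((2, 4), dtype=np.float32)))
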
@@ -120,6 +120,10 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
         type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
       }
     }
+  } else if (type->isa<IndexedSlicesType>()) {
+    // Do Nothing
+  } else if (type->isa<UndeterminedType>()) {
+    // Do Nothing
   } else if (type->isa<Tuple>()) {
     TuplePtr tuple_type = dyn_cast<Tuple>(type);
     type_proto->set_data_type(irpb::DT_TUPLE);

@@ -94,6 +94,48 @@ bool Slice::operator==(const Type &other) const {

 std::string Slice::DumpText() const { return ToString(); }

+TypePtr UndeterminedType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<UndeterminedType>();
+  }
+  return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
+}
+
+std::string UndeterminedType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->ToReprString() + "]";
+}
+
+std::string UndeterminedType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->ToString() + "]";
+}
+
+std::string UndeterminedType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->DumpText() + "]";
+}
+
+bool UndeterminedType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 TypePtr TensorType::DeepCopy() const {
   MS_EXCEPTION_IF_NULL(element_type_);
   if (IsGeneric()) {

@@ -137,6 +179,48 @@ bool TensorType::operator==(const Type &other) const {
   return *element_type_ == *other_elem_type;
 }

+TypePtr IndexedSlicesType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<IndexedSlicesType>();
+  }
+  return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
+}
+
+std::string IndexedSlicesType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->ToReprString() + "]";
+}
+
+std::string IndexedSlicesType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->ToString() + "]";
+}
+
+std::string IndexedSlicesType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->DumpText() + "]";
+}
+
+bool IndexedSlicesType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 Function::Function() : Object(kObjectTypeFunction) {
   args_ = std::vector<TypePtr>();
   retval_ = nullptr;

@@ -108,10 +108,34 @@ class Slice : public Object {
 };
 using SlicePtr = std::shared_ptr<Slice>;

+class UndeterminedType : public Object {
+ public:
+  UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
+  explicit UndeterminedType(const TypePtr &ele)
+      : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
+  ~UndeterminedType() override = default;
+  MS_DECLARE_PARENT(UndeterminedType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ protected:
+  TypePtr element_type_;
+};
+using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
+
 class TensorType : public Object {
  public:
-  TensorType() : Object(kObjectTypeTensorType) {}
-  explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
+  TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
+  explicit TensorType(const TypePtr &ele)
+      : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
   ~TensorType() override = default;
   MS_DECLARE_PARENT(TensorType, Object)

@@ -130,6 +154,29 @@ class TensorType : public Object {
 };
 using TensorTypePtr = std::shared_ptr<TensorType>;

+class IndexedSlicesType : public Object {
+ public:
+  IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
+  explicit IndexedSlicesType(const TypePtr &ele)
+      : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
+  ~IndexedSlicesType() override = default;
+  MS_DECLARE_PARENT(IndexedSlicesType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ private:
+  TypePtr element_type_;
+};
+using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
+
 class Function : public Object {
  public:
   Function();

@@ -255,6 +302,8 @@ TypePtr StringToType(const std::string &type_name);
 // Judge whether x is predicate or is a subclass of predicate.
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);

+bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type);
+
 // Whether t1 is identity or a subclass of t2.
 bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);

@@ -115,6 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) {
       return "kObjectTypeKeyword";
     case kObjectTypeTensorType:
       return "kObjectTypeTensorType";
+    case kObjectTypeIndexedSlicesType:
+      return "kObjectTypeIndexedSlicesType";
+    case kObjectTypeUndeterminedType:
+      return "kObjectTypeUndeterminedType";
     case kObjectTypeDictionary:
       return "kObjectTypeDictionary";
     case kObjectTypeClass:

@@ -67,6 +67,7 @@ class Type : public Value {
   virtual bool equal(const TypePtr other) const { return *this == *other; }

   virtual TypeId object_type() const { return kTypeUnknown; }
+  virtual TypeId parent_type() const { return kTypeUnknown; }
   virtual TypeId number_type() const { return kTypeUnknown; }
   virtual TypePtr DeepCopy() const = 0;
   virtual TypePtr Clone() const { return DeepCopy(); }

@@ -97,13 +98,16 @@ using TypePtrList = std::vector<TypePtr>;
 //
 class Object : public Type {
  public:
-  Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject) {}
+  Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {}
   explicit Object(const TypeId object_type, bool is_generic = true)
-      : Type(kMetaTypeObject, is_generic), object_type_(object_type) {}
+      : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {}
+  explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true)
+      : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {}
   ~Object() override = default;
   MS_DECLARE_PARENT(Object, Type)

   TypeId object_type() const override { return object_type_; }
+  TypeId parent_type() const override { return parent_type_; }
   TypeId type_id() const override { return object_type_; }
   TypeId generic_type_id() const override { return kMetaTypeObject; }
   bool equal(const TypePtr other) const override;

@@ -114,6 +118,7 @@ class Object : public Type {

  private:
   const TypeId object_type_;
+  const TypeId parent_type_;
 };

 std::ostream &operator<<(std::ostream &os, const TypePtrList &types);

@@ -50,6 +50,8 @@ enum TypeId : int {
   kObjectTypeSlice,
   kObjectTypeKeyword,
   kObjectTypeTensorType,
+  kObjectTypeIndexedSlicesType,
+  kObjectTypeUndeterminedType,
   kObjectTypeClass,
   kObjectTypeDictionary,
   kObjectTypeFunction,

@@ -192,6 +192,40 @@ TypePtr TensorStrToType(const std::string &type_name) {
   return type;
 }

+TypePtr IndexedSlicesStrToType(const std::string &type_name) {
+  if (type_name == "IndexedSlices") {
+    return std::make_shared<IndexedSlicesType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<IndexedSlicesType>(element_type);
+}
+
+TypePtr UndeterminedStrToType(const std::string &type_name) {
+  if (type_name == "Undetermined") {
+    return std::make_shared<UndeterminedType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<UndeterminedType>(element_type);
+}
+
 TypePtr ListStrToType(const std::string &type_name) {
   TypePtr type = nullptr;
   if (type_name == "List") {

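The two string parsers above share one step: strip the `Name[...]` brackets and recurse on the element type via `StringToType`. A rough Python equivalent of that bracket handling (toy code, not a MindSpore API):

    def split_bracketed(name, prefix):
        """Return the element-type substring of e.g. 'IndexedSlices[Float32]'."""
        if name == prefix:
            return None          # generic form, no element type
        start = name.find("[") + 1
        end = name.rfind("]")
        if start <= 0 or end < start:
            return ""            # malformed, mirrors the nullptr paths
        return name[start:end]

    assert split_bracketed("IndexedSlices[Float32]", "IndexedSlices") == "Float32"
    assert split_bracketed("Undetermined", "Undetermined") is None
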
@@ -313,6 +347,10 @@ TypePtr StringToType(const std::string &type_name) {
     type = StringToNumberType<Float>(type_name, "Float");
   } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
     type = TensorStrToType(type_name);
+  } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
+    type = UndeterminedStrToType(type_name);
+  } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
+    type = IndexedSlicesStrToType(type_name);
   } else if (type_name.compare(0, strlen("List"), "List") == 0) {
     type = ListStrToType(type_name);
   } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {

@@ -340,6 +378,20 @@ TypePtr StringToType(const std::string &type_name) {
   return type;
 }

+bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) {
+  if (x == nullptr || base_type == nullptr) {
+    MS_LOG(ERROR) << "Type is nullptr.";
+    return false;
+  }
+  if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
+    return false;
+  }
+  if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) {
+    return true;
+  }
+  return false;
+}
+
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
   if (x == nullptr || base_type == nullptr) {
     MS_LOG(ERROR) << "Type is nullptr.";

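`IsParentOrChildrenType` is deliberately symmetric: it fires when either type names the other as its `parent_type()`. Since `TensorType` and `IndexedSlicesType` both declare `kObjectTypeUndeterminedType` as parent, a toy model of the relation looks like this (illustration only, not MindSpore code):

    PARENT = {
        "Tensor": "Undetermined",
        "IndexedSlices": "Undetermined",
        "Undetermined": None,
    }

    def is_parent_or_children(x, base):
        # Mirrors the C++ check: x's parent is base, or base's parent is x.
        return PARENT.get(x) == base or PARENT.get(base) == x

    assert is_parent_or_children("Tensor", "Undetermined")
    assert is_parent_or_children("Undetermined", "IndexedSlices")
    assert not is_parent_or_children("Tensor", "IndexedSlices")  # siblings
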
@@ -481,6 +533,10 @@ REGISTER_PYBIND_DEFINE(
        TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
        return data;
      }));
+  (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
+    .def(py::init());
+  (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
+    .def(py::init());
   (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
     .def(py::init())
     .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));

@@ -501,6 +557,8 @@ const TypePtr kTypeExternal = std::make_shared<External>();
 const TypePtr kTypeEnv = std::make_shared<EnvType>();
 const TypePtr kTypeType = std::make_shared<TypeType>();
 const TypePtr kTensorType = std::make_shared<TensorType>();
+const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
+const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
 const TypePtr kString = std::make_shared<String>();
 const TypePtr kList = std::make_shared<List>();
 const TypePtr kTuple = std::make_shared<Tuple>();

@@ -93,15 +93,17 @@ static TypePtr UnwrapRef(const TypePtr &type) {
   }
   return type;
 }
-FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
-  bool find_fn = false;
-  py::function py_fn;
+
+// Return Exact match if exists, else return non ambiguous sub class match
+// Return py::none() if matching is ambiguous
+const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
+  // Exact match
   for (auto &item : fn_cache_py_) {
     TypePtrList sign = item.first;
     if (sign.size() != types.size()) {
       continue;
     }
-    bool match = true;
+    auto match = true;
     for (size_t i = 0; i < sign.size(); ++i) {
       if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
         match = false;

@@ -111,13 +113,45 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
     if (!match) {
       continue;
     }
-    find_fn = true;
-    py_fn = item.second;
-    break;
+    return item.second;
   }
+  // Try best match
+  py::function py_fn_subclass;
+  size_t subclass_match_cnt = 0;
+  for (auto &item : fn_cache_py_) {
+    TypePtrList sign = item.first;
+    if (sign.size() != types.size()) {
+      continue;
+    }
+    auto match = true;
+    for (size_t i = 0; i < sign.size(); ++i) {
+      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
+          !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
+        match = false;
+        break;
+      }
+    }
+    if (!match) {
+      continue;
+    }
+    py_fn_subclass = item.second;
+    subclass_match_cnt++;
+  }
+  if (subclass_match_cnt > 1) {
+    MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
+  }
+  if (subclass_match_cnt == 1) {
+    MS_LOG(DEBUG) << "Found one subclass match";
+    return py_fn_subclass;
+  }
+  return py::none();
+}
+
+FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
+  auto py_fn = SignMatch(types);
   std::ostringstream buffer;
   buffer << types;
-  if (find_fn) {
+  if (py_fn != py::none()) {
     FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
     if (func_graph == nullptr) {
       MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();

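`SignMatch` resolves an overload in two passes: exact signature match first, then a relaxed pass that also accepts parent/child-compatible types, raising if that fallback is ambiguous. A compact Python model of the same policy (toy code, not the MindSpore implementation):

    def sign_match(registry, types, compatible):
        """registry: {signature tuple: fn}; compatible(t, s): relaxed test."""
        for sign, fn in registry.items():            # pass 1: exact match
            if sign == types:
                return fn
        matches = [fn for sign, fn in registry.items()
                   if len(sign) == len(types) and
                   all(t == s or compatible(t, s) for t, s in zip(types, sign))]
        if len(matches) > 1:
            raise TypeError("more than one overload matches by subclass")
        return matches[0] if matches else None

    reg = {("Tensor", "Tensor"): "fn_tt", ("Undetermined", "Tensor"): "fn_ut"}
    compat = lambda t, s: "Undetermined" in (t, s)
    assert sign_match(reg, ("Tensor", "Tensor"), compat) == "fn_tt"         # exact
    assert sign_match(reg, ("IndexedSlices", "Tensor"), compat) == "fn_ut"  # via parent
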
@@ -54,6 +54,7 @@ class MultitypeFuncGraph : public MetaFuncGraph {
   }

  private:
+  const py::function SignMatch(const TypePtrList &types);
   std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
   std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
 };

@@ -277,5 +277,12 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
 const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
 const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
 const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
+
+// IndexedSlices
+const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
+const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
+const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
+const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
+const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
 }  // namespace prim
 }  // namespace mindspore

@@ -287,6 +287,13 @@ extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;

+// IndexedSlices
+extern const PrimitivePtr kPrimMakeIndexedSlices;
+extern const PrimitivePtr kPrimIndexedSlicesGetValues;
+extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
+extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
+extern const PrimitivePtr kPrimIsIndexedSlices;
+
 class DoSignaturePrimitive : public Primitive {
  public:
   explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

@@ -24,6 +24,7 @@
 #include "pipeline/static_analysis/prim.h"
 #include "pipeline/static_analysis/utils.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"

 namespace mindspore {
 namespace abstract {

@@ -173,6 +174,13 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
     return std::make_shared<AbstractTuple>(sparse_list);
   }

+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
+    auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
+    return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
+  }
   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
     return dflt;
   }

@@ -236,6 +244,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
   }
   auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
   ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
+  ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
   return ret;
 }

@@ -437,5 +446,72 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
   }
   return std::make_shared<AbstractScalar>(kAnyValue, kBool);
 }
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  std::vector<int> dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int {
+                         auto elem = GetValue<int>(e);
+                         return elem;
+                       });
+  auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
+  ret->set_indices(indices);
+  ret->set_values(values);
+  ret->set_dense_shape(dense_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->values());
+  return indexed_slices->values();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->indices());
+  return indexed_slices->indices();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
+  return indexed_slices->dense_shape();
+}
+
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  bool ret = false;
+  if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
+    ret = true;
+  }
+  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
+  return std::make_shared<AbstractScalar>(ret);
+}
 }  // namespace abstract
 }  // namespace mindspore

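For intuition, `MakeIndexedSlices(indices, values, dense_shape)` describes a tensor in which row `indices[i]` of the dense result holds `values[i]`. A NumPy reference of that semantics (not MindSpore code; duplicate-index accumulation is an assumption here):

    import numpy as np

    def indexed_slices_to_dense(indices, values, dense_shape):
        dense = np.zeros(dense_shape, dtype=values.dtype)
        for i, row in zip(indices, values):
            dense[i] += row   # assumed: repeated indices accumulate
        return dense

    out = indexed_slices_to_dense(np.array([0, 2]),
                                  np.ones((2, 4), dtype=np.float32),
                                  (3, 4))
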
@@ -36,6 +36,7 @@ using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
 using mindspore::abstract::AbstractTuple;
+using mindspore::abstract::AbstractUndetermined;

 static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
   if (t == nullptr) {

@@ -78,7 +79,7 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(cons);

   auto dt = data->abstract();
-  if (dt == nullptr) {
+  if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
     return nullptr;
   }

@@ -42,6 +42,7 @@
 #include "optimizer/irpass/tile_eliminate.h"
 #include "optimizer/irpass/transpose_eliminate.h"
 #include "optimizer/opt.h"
+#include "optimizer/irpass/indexed_slices_eliminate.h"

 namespace mindspore {
 namespace opt {

@@ -153,6 +154,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Mark interface fusion
   mark_interface_fusion_ =
     MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
+
+  // IndexedSlices Eliminate
+  indexed_slices_eliminate_ = MakeSubstitution(
+    std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
+    {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
 }

 ResolveIRPassLib::ResolveIRPassLib() {

@@ -104,6 +104,9 @@ class OptimizeIRPassLib {

   // Fusion
   SubstitutionPtr mark_interface_fusion_;
+
+  // IndexedSlices Eliminate
+  SubstitutionPtr indexed_slices_eliminate_;
 };

 // the collection of irpass for resolve action

@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "optimizer/irpass.h"
+#include "optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
+// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
+// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
+class IndexedSlicesEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(1);
+    }
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(2);
+    }
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(3);
+    }
+    return nullptr;
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
+      tuple_ = cnode;
+      is_match_ = true;
+    }
+  }
+
+  void Reset() {
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  CNodePtr tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_

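The eliminator is a constant-folding-style rewrite: a getter applied directly to `MakeIndexedSlices` is replaced by the matching constructor argument (input 1, 2 or 3 of the CNode; input 0 is the primitive itself). The same idea on a toy expression tree (illustrative only, not MindSpore's IR):

    GETTER_ARG = {"get_indices": 1, "get_values": 2, "get_dense_shape": 3}

    def eliminate(node):
        """node is a tuple (op, *inputs); returns the folded node or node itself."""
        op, *args = node
        inner = args[0] if args else None
        if op in GETTER_ARG and isinstance(inner, tuple) and inner[0] == "make_indexed_slices":
            return inner[GETTER_ARG[op]]
        return node

    make = ("make_indexed_slices", "indices", "values", "dense_shape")
    assert eliminate(("get_values", make)) == "values"
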
@@ -232,6 +232,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
         auto sparse_grad =
           py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
         ptr->set_sparse_grad(sparse_grad);
+        auto has_indexed_slices_grad =
+          py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad"));
+        ptr->set_has_indexed_slices_grad(has_indexed_slices_grad);

         parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
         args_spec.push_back(ptr);

@@ -154,7 +154,9 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
     .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
          "Set the GraphKernel switch to on or off.")
-    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
+    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
+    .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
+    .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");

   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

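These bindings surface the sparse flag to Python through `_c_expression`. A rough sketch of toggling it (assuming `MsContext` exposes a `get_instance` accessor the way `MpiConfig` does above; that accessor is not shown in this hunk):

    from mindspore._c_expression import MsContext

    ctx = MsContext.get_instance()      # assumed accessor
    ctx.set_enable_sparse_flag(True)    # setter added in this hunk
    assert ctx.get_enable_sparse_flag()
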
@@ -156,6 +156,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
+    irpass.indexed_slices_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},

@@ -33,148 +33,157 @@ namespace mindspore {
 namespace pipeline {

 MethodMap &GetMethodMap() {
   static MethodMap method_map = {
     {kObjectTypeString,
      {
        {"__bool__", std::string("str_bool")}  // C.str_bool
      }},
     {kMetaTypeNone,
      {
        {"__bool__", std::string("none_bool")}  // C.none_bool
      }},
     {kNumberTypeBool,
      {
        {"__and__", prim::kPrimBoolAnd},     // P.bool_and
        {"__or__", prim::kPrimBoolOr},       // P.bool_or
        {"__eq__", prim::kPrimBoolEq},       // P.bool_eq
        {"__ne__", std::string("bool_ne")},  // C.bool_ne
        {"__bool__", prim::kPrimIdentity}    // P.identity
      }},
     {kNumberTypeInt,
      {
        {"__add__", prim::kPrimScalarAdd},              // P.scalar_add
        {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub
        {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul
        {"__floordiv__", std::string("int_floordiv")},  // C.int_floordiv
        {"__truediv__", std::string("int_truediv")},    // C.int_truediv
        {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod
        {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow
        {"__floor__", prim::kPrimIdentity},             // P.identity
        {"__trunc__", prim::kPrimIdentity},             // P.identity
        {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd
        {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub
        {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq
        {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne
        {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt
        {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt
        {"__le__", prim::kPrimScalarLe},                // P.scalar_le
        {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge
        {"__bool__", std::string("int_bool")},          // C.int_bool
        {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array
      }},
     {kNumberTypeUInt,
      {
        {"__add__", prim::kPrimScalarAdd},              // P.scalar_add,
        {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub,
        {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul,
        {"__floordiv__", prim::kPrimScalarDiv},         // P.scalar_div,
        {"__truediv__", std::string("int_truediv")},    // C.int_truediv
        {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod,
        {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow,
        {"__floor__", prim::kPrimIdentity},             // P.identity,
        {"__trunc__", prim::kPrimIdentity},             // P.identity,
        {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd,
        {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub,
        {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq,
        {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne,
        {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt,
        {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt,
        {"__le__", prim::kPrimScalarLe},                // P.scalar_le,
        {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge,
        {"__bool__", std::string("int_bool")},          // C.int_bool
        {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array,
      }},
     {kNumberTypeFloat,
      {
        {"__add__", prim::kPrimScalarAdd},                // P.scalar_add,
        {"__sub__", prim::kPrimScalarSub},                // P.scalar_sub,
        {"__mul__", prim::kPrimScalarMul},                // P.scalar_mul,
        {"__floordiv__", std::string("float_floordiv")},  // C.float_floordiv
        {"__truediv__", prim::kPrimScalarDiv},            // P.scalar_div,
        {"__mod__", prim::kPrimScalarMod},                // P.scalar_mod,
        {"__pow__", prim::kPrimScalarPow},                // P.scalar_pow,
        {"__floor__", prim::kPrimScalarFloor},            // P.scalar_floor,
        {"__trunc__", prim::kPrimScalarTrunc},            // P.scalar_trunc,
        {"__pos__", prim::kPrimScalarUadd},               // P.scalar_uadd,
        {"__neg__", prim::kPrimScalarUsub},               // P.scalar_usub,
        {"__eq__", prim::kPrimScalarEq},                  // P.scalar_eq,
        {"__ne__", prim::kPrimScalarNe},                  // P.scalar_ne,
        {"__lt__", prim::kPrimScalarLt},                  // P.scalar_lt,
        {"__gt__", prim::kPrimScalarGt},                  // P.scalar_gt,
        {"__le__", prim::kPrimScalarLe},                  // P.scalar_le,
        {"__ge__", prim::kPrimScalarGe},                  // P.scalar_ge,
        {"__bool__", std::string("float_bool")},          // C.float_bool
        {"__ms_to_array__", prim::kPrimScalarToArray},    // P.scalar_to_array,
      }},
     {kObjectTypeTuple,
      {
        {"__len__", prim::kPrimTupleLen},                  // P.tuple_len,
        {"__getitem__", prim::kPrimTupleGetItem},          // P.tuple_getitem,
        {"__setitem__", prim::kPrimTupleSetItem},          // P.tuple_setitem,
        {"__ms_iter__", prim::kPrimIdentity},              // P.identity,
        {"__ms_next__", std::string("tuple_next")},        // C.tuple_next,
        {"__ms_hasnext__", std::string("tuple_hasnext")},  // C.tuple_hasnext
        {"__bool__", std::string("tuple_bool")}            // C.tuple_bool
      }},
     {kObjectTypeList,
      {
        {"__len__", prim::kPrimListLen},            // P.list_len,
        {"__getitem__", prim::kPrimListGetItem},    // P.list_getitem,
        {"__setitem__", prim::kPrimListSetItem},    // P.list_setitem,
        {"__ms_iter__", prim::kPrimIdentity},       // P.identity
        {"__ms_next__", std::string("list_next")},  // C.list_next
        {"append", std::string("list_append")},     // C.list_next
        {"__bool__", std::string("list_bool")},     // C.list_bool
        {"__ms_hasnext__", std::string("list_hasnext")},
      }},
     {kObjectTypeDictionary,
      {
        {"__len__", prim::kPrimDictLen},          // P.dict_len
        {"__getitem__", prim::kPrimDictGetItem},  // P.dict_getitem
        {"__setitem__", prim::kPrimDictSetItem},  // P.dict_setitem,
        {"__bool__", std::string("dict_bool")}    // C.dict_bool
      }},
     {kObjectTypeTensorType,
      {
        {"__add__", std::string("add")},                    // C.add
        {"__sub__", std::string("sub")},                    // C.sub
        {"__mul__", std::string("mul")},                    // C.mul
        {"__truediv__", std::string("truediv")},            // C.truediv
        {"__floordiv__", std::string("floordiv")},          // C.floordiv
        {"__mod__", std::string("mod")},                    // C.mod
        {"__pow__", std::string("pow_")},                   // C.pow
        {"__floor__", std::string("array_floor")},          // C.array_floor
        {"__trunc__", std::string("array_trunc")},          // C.array_trunc
        {"__pos__", std::string("array_uadd")},             // C.array_uadd
        {"__neg__", std::string("array_usub")},             // C.array_usub
        {"__eq__", std::string("eq")},                      // C.eq
        {"__ne__", std::string("ne")},                      // C.ne
        {"__lt__", std::string("lt")},                      // C.lt
        {"__gt__", std::string("gt")},                      // C.gt
        {"__le__", std::string("le")},                      // C.le
        {"__ge__", std::string("ge")},                      // C.ge
        {"__matmul__", prim::kPrimDot},                     // P.dot,
        {"__len__", prim::kPrimArrayLen},                   // P.array_len,
        {"__getitem__", prim::kPrimArrayGetItem},           // P.array_getitem,
        {"__setitem__", prim::kPrimArraySetItem},           // P.array_setitem,
        {"__ms_iter__", std::string("array_iter")},         // C.array_iter
        {"__ms_to_array__", prim::kPrimIdentity},           // P.identity,
        {"item", prim::kPrimArrayToScalar},                 // P.array_to_scalar,
        {"transpose", std::string("transpose")},            // P.transpose
        {"__bool__", std::string("tensor_bool")},           // C.tensor_bool
+       {"is_indexed_slices", prim::kPrimIsIndexedSlices},  // F.is_indexed_slices
      }},
+    {kObjectTypeIndexedSlicesType,
+     {
+       {"is_indexed_slices", prim::kPrimIsIndexedSlices},       // F.is_indexed_slices
+       {"values", prim::kPrimIndexedSlicesGetValues},           // F.indexed_slices_get_values
+       {"indices", prim::kPrimIndexedSlicesGetIndices},         // F.indexed_slices_get_indices
+       {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape},  // F.indexed_slices_get_dense_shape
+     }},
     {kObjectTypeJTagged, {}},
     {kObjectTypeSymbolicKeyType, {}},
     {kObjectTypeEnvType, {}}};
   return method_map;
 }

@@ -30,6 +30,10 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
   if (tid() != other.tid()) {
     return false;
   }
+  if (BuildType()->type_id() == kObjectTypeUndeterminedType &&
+      other.BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    return true;
+  }
   if (value_ == nullptr || other.value_ == nullptr) {
     MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
                       << this->ToString() << ", other: " << other.ToString();

@@ -65,7 +69,7 @@ std::string AbstractBase::ToString() const {
   MS_EXCEPTION_IF_NULL(shape_);
   buffer << type_name() << "("
          << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
-         << " sparse_grad: " << sparse_grad_ << ")";
+         << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
   return buffer.str();
 }

@@ -76,6 +80,7 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   if (*this == *other) {
     auto ret = shared_from_base<AbstractBase>();
     ret->set_sparse_grad(sparse_grad());
+    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
     return ret;
   }
   auto value_self = GetValueTrack();

@ -85,10 +90,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
||||||
if (res_value == value_self) {
|
if (res_value == value_self) {
|
||||||
auto ret = shared_from_base<AbstractBase>();
|
auto ret = shared_from_base<AbstractBase>();
|
||||||
ret->set_sparse_grad(sparse_grad());
|
ret->set_sparse_grad(sparse_grad());
|
||||||
|
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
|
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
|
||||||
ret->set_sparse_grad(sparse_grad());
|
ret->set_sparse_grad(sparse_grad());
|
||||||
|
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -409,6 +416,14 @@ std::size_t AbstractSlice::hash() const {
   return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
 }
 
+ShapePtr AbstractUndetermined::shape() const {
+  auto shp = dyn_cast<Shape>(GetShapeTrack());
+  if (shp == nullptr) {
+    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
+  }
+  return shp;
+}
+
 TypePtr AbstractTensor::BuildType() const {
   MS_EXCEPTION_IF_NULL(element_);
   TypePtr element_type = element_->BuildType();
@@ -425,6 +440,13 @@ BaseShapePtr AbstractTensor::BuildShape() const {
 }
 
 AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
+  if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    auto other_tensor = dyn_cast<AbstractUndetermined>(other);
+    auto element = element_->Join(other_tensor->element());
+    auto shape = ShapeJoin(this->shape(), other_tensor->shape());
+    auto ret = std::make_shared<AbstractUndetermined>(element, shape);
+    return ret;
+  }
   auto other_tensor = dyn_cast<AbstractTensor>(other);
   if (other_tensor == nullptr) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
@@ -433,6 +455,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
   auto ret = std::make_shared<AbstractTensor>(element, shape);
   ret->set_sparse_grad(sparse_grad());
+  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return ret;
 }
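The new branch above keeps the join total: joining a concrete tensor with a still-undetermined value yields an AbstractUndetermined instead of failing the type match. A rough Python analogue of that lattice rule (names are illustrative only, not MindSpore API):

    # Illustrative-only sketch of the join rule; not MindSpore API.
    def join_abstract(a, b):
        # An undetermined side forces the result to stay undetermined,
        # deferring the decision to a later inference pass.
        if "Undetermined" in (a, b):
            return "Undetermined"
        return a if a == b else "Any"

    assert join_abstract("Tensor[float32]", "Undetermined") == "Undetermined"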
@@ -474,6 +497,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
   clone->set_sparse_grad(sparse_grad());
+  clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return clone;
 }
@@ -484,6 +508,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
   broaden->set_sparse_grad(sparse_grad());
+  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
@@ -495,17 +520,10 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
   broaden->set_shape(shp);
   broaden->set_value(kAnyValue);
   broaden->set_sparse_grad(sparse_grad());
+  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
 
-ShapePtr AbstractTensor::shape() const {
-  auto shp = dyn_cast<Shape>(GetShapeTrack());
-  if (shp == nullptr) {
-    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
-  }
-  return shp;
-}
-
 std::string AbstractTensor::ToString() const {
   std::ostringstream buffer;
   BaseShapePtr shape_track = GetShapeTrack();
@@ -516,7 +534,7 @@ std::string AbstractTensor::ToString() const {
   buffer << type_name() << "("
          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
          << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
-         << ")";
+         << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
   return buffer.str();
 }
@@ -1019,5 +1037,64 @@ std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &arg
 bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
   return AbstractBasePtrListDeepEqual(lhs, rhs);
 }
 
+// IndexedSlices
+TypePtr AbstractIndexedSlices::BuildType() const {
+  MS_EXCEPTION_IF_NULL(element());
+  TypePtr element_type = element()->BuildType();
+  return std::make_shared<IndexedSlicesType>(element_type);
+}
+
+AbstractBasePtr AbstractIndexedSlices::Clone() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
+  ShapePtr shp = shape();
+  clone->set_shape(shp->Clone());
+  clone->set_value(GetValueTrack());
+  clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return clone;
+}
+
+AbstractBasePtr AbstractIndexedSlices::Broaden() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto shp = shape();
+  broaden->set_shape(shp->Clone());
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto shp = shape()->Clone();
+  shp->Broaden();
+  broaden->set_shape(shp);
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+std::string AbstractIndexedSlices::ToString() const {
+  std::ostringstream buffer;
+  BaseShapePtr shape_track = GetShapeTrack();
+  MS_EXCEPTION_IF_NULL(shape_track);
+  MS_EXCEPTION_IF_NULL(element());
+  auto value_track = GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_track);
+  buffer << type_name() << "("
+         << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
+         << ", indices: " << indices_->ToString() << ", values: " << values_->ToString()
+         << ", dense_shape: " << dense_shape_->ToString();
+  return buffer.str();
+}
 }  // namespace abstract
 }  // namespace mindspore
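For intuition, the three members tracked here describe a sparse tensor in row form: `indices` names the rows that carry data, `values` holds those rows, and `dense_shape` is the shape of the full tensor. A NumPy sketch of the densification semantics (illustrative only; nothing in this commit materializes the dense tensor):

    import numpy as np

    def densify(indices, values, dense_shape):
        # Scatter-add each value row into a zero tensor of the full shape.
        out = np.zeros(dense_shape, dtype=values.dtype)
        for i, row in zip(indices, values):
            out[i] += row
        return out

    # Rows 0 and 2 of a (3, 2) tensor carry data; row 1 stays zero.
    dense = densify(np.array([0, 2]), np.array([[1., 1.], [2., 2.]]), (3, 2))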
@@ -44,7 +44,7 @@ class AbstractBase : public Base {
  public:
   explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                         const BaseShapePtr &shape = kNoShape)
-      : value_(value), type_(type), shape_(shape), sparse_grad_("") {}
+      : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
   ~AbstractBase() override = default;
   MS_DECLARE_PARENT(AbstractBase, Base)
@@ -54,12 +54,16 @@ class AbstractBase : public Base {
   virtual bool operator==(const AbstractBase &other) const;
   void set_value(const ValuePtr &value) { value_ = value; }
   void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
+  void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
+    has_indexed_slices_grad_ = has_indexed_slices_grad;
+  }
   void set_type(const TypePtr &type) { type_ = type; }
   void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
   void set_value_desc(const std::string &desc) { value_desc_ = desc; }
   const std::string &value_desc() const { return value_desc_; }
   ValuePtr GetValueTrack() const { return value_; }
   const std::string &sparse_grad() const { return sparse_grad_; }
+  const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
   TypePtr GetTypeTrack() const { return type_; }
   BaseShapePtr GetShapeTrack() const { return shape_; }
@@ -88,6 +92,7 @@ class AbstractBase : public Base {
   BaseShapePtr shape_;
   std::string value_desc_;  // store initial value description for error report
   std::string sparse_grad_;
+  bool has_indexed_slices_grad_;
 };
 
 class AbstractScalar : public AbstractBase {
@@ -231,35 +236,49 @@ class AbstractKeywordArg : public AbstractBase {
 };
 using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
 
-class AbstractTensor : public AbstractBase {
+class AbstractUndetermined : public AbstractBase {
  public:
+  // shape and type are all unknown
+  AbstractUndetermined() : AbstractBase(kAnyValue) {}
   // only element_ and value, shape track are valid member, type track are unknown.
-  explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+  explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
       : AbstractBase(kAnyValue), element_(element) {
     if (element == nullptr) {
       MS_LOG(EXCEPTION) << "element is nullptr";
     }
-    if (element->isa<AbstractTensor>()) {
+    if (element->isa<AbstractUndetermined>()) {
      MS_LOG(EXCEPTION) << "element type error";
     }
     set_shape(shape);
   }
-  AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
+  AbstractUndetermined(const TypePtr &element_type, const std::vector<int> &shape)
       : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
     if (element_type == nullptr) {
       MS_LOG(EXCEPTION) << "element_type is nullptr";
     }
     set_shape(std::make_shared<Shape>(shape));
   }
-  explicit AbstractTensor(const tensor::TensorPtr &tensor)
-      : AbstractBase(tensor), element_(std::make_shared<AbstractScalar>(kAnyValue, tensor->Dtype())) {
-    if (tensor == nullptr) {
-      MS_LOG(EXCEPTION) << "tensor is nullptr";
-    }
-    set_shape(std::make_shared<Shape>(tensor->shape()));
-  }
+  ~AbstractUndetermined() override = default;
+  MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
+  TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
+  AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }
+  const AbstractBasePtr element() const { return element_; }
+  ShapePtr shape() const;
+
+ protected:
+  AbstractBasePtr element_;
+};
+
+class AbstractTensor : public AbstractUndetermined {
+ public:
+  // only element_ and value, shape track are valid member, type track are unknown.
+  explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+      : AbstractUndetermined(element, shape) {}
+  AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
+      : AbstractUndetermined(element_type, shape) {}
+  explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
   ~AbstractTensor() override = default;
-  MS_DECLARE_PARENT(AbstractTensor, AbstractBase)
+  MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)
 
   TypePtr BuildType() const override;
   BaseShapePtr BuildShape() const override;
@@ -271,9 +290,7 @@ class AbstractTensor : public AbstractBase {
   bool operator==(const AbstractTensor &other) const;
   bool operator==(const AbstractBase &other) const override;
 
-  ShapePtr shape() const;
   std::string ToString() const override;
-  const AbstractBasePtr element() const { return element_; }
   std::size_t hash() const override {
     auto value = GetValueTrack();
     auto hash_sum = hash_combine(tid(), element_->hash());
@@ -285,9 +302,6 @@ class AbstractTensor : public AbstractBase {
     }
     return hash_sum;
   }
-
- private:
-  AbstractBasePtr element_;
 };
 using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
 using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
@@ -585,6 +599,35 @@ struct AbstractBasePtrListEqual {
 
 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
+
+// IndexedSlices
+class AbstractIndexedSlices : public AbstractUndetermined {
+ public:
+  explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+      : AbstractUndetermined(element, shape) {}
+  AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape)
+      : AbstractUndetermined(element_type, shape) {}
+  ~AbstractIndexedSlices() override = default;
+  MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
+
+  const AbstractTensorPtr indices() const { return indices_; }
+  const AbstractTensorPtr values() const { return values_; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
+  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
+  TypePtr BuildType() const override;
+  AbstractBasePtr Clone() const override;
+  AbstractBasePtr Broaden() const override;
+  AbstractBasePtr BroadenWithShape() const;
+
+  std::string ToString() const override;
+
+ private:
+  AbstractTensorPtr indices_;
+  AbstractTensorPtr values_;
+  AbstractTuplePtr dense_shape_;
+};
 }  // namespace abstract
 }  // namespace mindspore
 #endif  // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_
@@ -58,6 +58,20 @@ class Evaluator : public Base {
     return args_spec_list;
   }
 
+  virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
+    auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
+      if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+        return true;
+      }
+      return false;
+    });
+    if (is_abstract) {
+      MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
+      return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
+    }
+    return nullptr;
+  }
+
   std::string ToString() const override { return identifier_; }
 
   virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
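AbstractEval is the shared short-circuit used by the evaluators changed below: when any argument still has undetermined type, concrete inference is skipped and an AbstractUndetermined result is returned so analysis can continue. A Python analogue of the control flow (illustrative, not the C++ API):

    # Illustrative analogue of Evaluator::AbstractEval's short-circuit.
    def abstract_eval(args, concrete_infer):
        if any(a == "Undetermined" for a in args):
            return "Undetermined"  # defer; a later pass supplies the real type
        return concrete_infer(args)

    assert abstract_eval(["Tensor", "Undetermined"], len) == "Undetermined"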
@@ -66,6 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
 ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
+ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
 
 template <typename T>
 std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
@@ -36,6 +36,7 @@
 #include "pipeline/parse/resolve.h"
 #include "ir/tensor.h"
 #include "utils/convert_utils.h"
+#include "utils/context/ms_context.h"
 #include "pipeline/parse/data_converter.h"
 #include "pipeline/static_analysis/param_validator.h"
 #include "common/utils.h"
@@ -132,6 +133,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimControlDepend, {InferImplControlDepend, true}},
     // Debug
     {prim::kPrimDebug, {InferImplDebug, true}},
+    // IndexedSlices
+    {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}},
+    {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
+    {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
+    {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
+    {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
   };
   return prim_eval_implement_map;
 }
@@ -139,6 +146,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 using mindspore::parse::PyObjectWrapper;
 
 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   prim_->BeginRecordAddAttr();
   AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
   prim_->EndRecordAddAttr();
@@ -485,6 +502,16 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 }  // end anonymous namespace
 
 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
 
   const auto &iter = cache_->find(args);
@@ -512,6 +539,16 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
 }
 
 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
   if (nargs_ != args.size()) {
     MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
@@ -871,6 +908,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
       auto ref_value = ref_abs->ref();
       MS_EXCEPTION_IF_NULL(ref_value);
       ret->set_sparse_grad(ref_value->sparse_grad());
+      ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
       return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
     }
@@ -886,6 +924,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
     abs_scalar->set_sparse_grad(x->sparse_grad());
+    abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
   }
 };
@@ -897,6 +936,16 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse_flag = context->enable_sparse_flag();
+    if (enable_sparse_flag) {
+      auto ret_abstract = AbstractEval(args_spec_list);
+      if (ret_abstract != nullptr) {
+        MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
+        return ret_abstract;
+      }
+    }
     // Inputs: data, item
     if (args_spec_list.size() != 2) {
       MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
@@ -350,6 +350,17 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
 void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore
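Each declaration above pairs an IndexedSlices primitive with a shape/type inference routine. As a hedged sketch, the core consistency check InferImplMakeIndexedSlices has to establish looks roughly like this (paraphrased; the real checks live in the C++ implementation, which this diff does not show):

    # Paraphrased inference for MakeIndexedSlices(indices, values, dense_shape).
    def infer_make_indexed_slices(indices_shape, values_shape, dense_shape):
        # One value row per selected index along axis 0.
        assert indices_shape[0] == values_shape[0], "indices/values count mismatch"
        # Value rows must match the dense tensor's trailing dimensions.
        assert tuple(values_shape[1:]) == tuple(dense_shape[1:]), "row shape mismatch"
        return ("IndexedSlices", dense_shape)

    infer_make_indexed_slices((2,), (2, 1, 2), (3, 1, 2))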
@@ -228,6 +228,10 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
     MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
                       << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
   }
+  if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
+    return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
+  }
   AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
   if (func == nullptr) {
     MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString()
@@ -32,6 +32,7 @@ using mindspore::abstract::AbstractBase;
 using mindspore::abstract::AbstractClass;
 using mindspore::abstract::AbstractError;
 using mindspore::abstract::AbstractFunction;
+using mindspore::abstract::AbstractIndexedSlices;
 using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
@@ -93,7 +94,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
   }
 
   if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
-      ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
+      ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
+      ptrBase->isa<abstract::AbstractRefKey>()) {
     return;
   }
@@ -89,6 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   max_device_memory_ = kDefaultMaxDeviceMemory;
   print_file_path_ = "";
   enable_graph_kernel_ = false;
+  enable_sparse_flag_ = false;
 }
 
 std::shared_ptr<MsContext> MsContext::GetInstance() {
@@ -161,6 +161,9 @@ class MsContext {
   void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
   bool enable_graph_kernel() const { return enable_graph_kernel_; }
 
+  bool enable_sparse_flag() const { return enable_sparse_flag_; }
+  void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
+
  private:
   MsContext(const std::string &backend_policy, const std::string &target);
   void GetGeOptions(std::map<std::string, std::string> *ge_options) const;
@@ -204,6 +207,7 @@ class MsContext {
   float max_device_memory_;
   std::string print_file_path_;
   bool enable_graph_kernel_;
+  bool enable_sparse_flag_;
 };
 
 }  // namespace mindspore
@@ -17,10 +17,10 @@ from . import dtype
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor
+from .tensor import MetaTensor, Tensor, IndexedSlices
 
 __all__ = [
-    "MetaTensor", "Tensor",  # tensor
+    "MetaTensor", "Tensor", "IndexedSlices",  # tensor
     'ms_function',  # api
     'Parameter', 'ParameterTuple',  # parameter
     "dtype"
@@ -52,13 +52,16 @@ class Parameter:
         layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
             broadcast and gradients communication would not be applied on parameters. Default: False.
         sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
+        has_indexed_slices_grad (bool): Set if the parameter's gradient is indexed_slices. Default: False.
     """
-    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=""):
+    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
+                 sparse_grad="", has_indexed_slices_grad=False):
         self.set_parameter_data(default_input)
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
         self.sparse_grad = sparse_grad
+        self.has_indexed_slices_grad = has_indexed_slices_grad
         self._is_init = False
         self._sliced = False
         self.clone_info = _CloneInfo()
@@ -186,6 +189,17 @@ class Parameter:
             raise TypeError("`sparse_grad` parameter must be str type")
         self._sparse_grad = value
 
+    @property
+    def has_indexed_slices_grad(self):
+        """Return whether the parameter's gradient is indexed_slices."""
+        return self._has_indexed_slices_grad
+
+    @has_indexed_slices_grad.setter
+    def has_indexed_slices_grad(self, value=False):
+        if not isinstance(value, bool):
+            raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
+        self._has_indexed_slices_grad = value
+
     @property
     def data(self):
         return self.default_input
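With the new keyword in place, a parameter can declare up front that its gradient arrives as indexed slices, mirroring the test network added at the end of this commit:

    import numpy as np
    from mindspore import Tensor
    from mindspore.common.parameter import Parameter

    # Declare that this parameter's gradient will be an IndexedSlices.
    w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
                   name="w1", has_indexed_slices_grad=True)
    assert w1.has_indexed_slices_grad is True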
@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
 
-__all__ = ['Tensor', 'MetaTensor']
+__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)
@@ -214,3 +214,8 @@ class Tensor(Tensor_):
             raise TypeError("init_flag must be bool.")
         self.set_init_flag(value)
         self._init_flag = value
+
+
+class IndexedSlices:
+    def __init__(self, indices, values, dense_shape):
+        raise NotImplementedError
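Note that the Python constructor deliberately raises: at this stage IndexedSlices can only be instantiated inside a compiled graph, where the parser maps the class to the MakeIndexedSlices primitive. A minimal graph-mode sketch, following the tests added at the end of this commit:

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, IndexedSlices, context

    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

    class MakeIS(nn.Cell):
        def construct(self, indices, values):
            # Inside construct this call becomes MakeIndexedSlices.
            return IndexedSlices(indices, values, (3, 4)).is_indexed_slices()

    MakeIS()(Tensor([[0, 0], [1, 2]]), Tensor([1, 2], dtype=ms.float32))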
@@ -355,6 +355,14 @@ class _Context:
     def check_bprop(self, check_bprop_flag):
         self._context_handle.set_check_bprop_flag(check_bprop_flag)
 
+    @property
+    def enable_sparse(self):
+        return self._context_handle.get_enable_sparse_flag()
+
+    @enable_sparse.setter
+    def enable_sparse(self, enable_sparse_flag):
+        self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
+
     @property
     def max_device_memory(self):
         return self._context_handle.get_max_device_memory()
@@ -510,7 +518,8 @@ def reset_auto_parallel_context():
             save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
             save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
             enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
-            enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str)
+            enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
+            enable_sparse=bool)
 def set_context(**kwargs):
     """
     Sets context for running environment.
@@ -567,6 +576,7 @@ def set_context(**kwargs):
             The format is "xxGB". Default: "1024GB".
         print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
             a file by default, and turn off printing to the screen.
+        enable_sparse (bool): Whether to enable sparse feature. Default: False.
 
     Raises:
         ValueError: If input key is not an attribute in context.
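Turning the feature on is then a single context call; the evaluator changes shown earlier only take the Undetermined short-circuit when this flag is set:

    from mindspore import context

    # Opt in to sparse (IndexedSlices) handling during graph compilation.
    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)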
@@ -153,6 +153,14 @@ shape_mul = Primitive("shape_mul")
 # a primitive to compare between tuple.
 stop_gradient = Primitive("stop_gradient")
 
+
+make_indexed_slices = Primitive('MakeIndexedSlices')
+indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
+indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
+indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
+is_indexed_slices = Primitive('IsIndexedSlices')
+
+
 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
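These functional handles expose the new primitives to composite code. The `indexed_slices_eliminate_` optimization exercised by the tests below rests on the getter/constructor identities, sketched here with plain tuples (a plain-Python model of the rewrite, not graph code):

    # get_*(make_indexed_slices(x, y, z)) collapses to the matching input.
    def make_indexed_slices(indices, values, dense_shape):
        return ("IndexedSlices", indices, values, dense_shape)

    def get_indices(node): return node[1]
    def get_values(node): return node[2]
    def get_dense_shape(node): return node[3]

    n = make_indexed_slices("x", "y", "z")
    assert (get_indices(n), get_values(n), get_dense_shape(n)) == ("x", "y", "z")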
@@ -564,7 +564,7 @@ class SparseGatherV2(GatherV2):
         >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
         >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
         >>> axis = 1
-        >>> out = P.GatherV2()(input_params, input_indices, axis)
+        >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
     """
@@ -603,5 +603,18 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
   ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
   ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
 }
+
+TEST_F(TestOptLib, test_indexed_slices) {
+  FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices");
+  FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices");
+  FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values");
+  FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values");
+  FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape");
+  FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_});
+  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -1130,3 +1130,38 @@ def test_adjust_allreduce_mul_add(tag):
         return Mul(AllReduce(AddN((Mul(z, z), x))), y)
 
     return fns[tag]
+
+
+def test_indexed_slices(tag):
+    """ test_indexed_slices """
+    fns = FnDict()
+    make_indexed_slices = Primitive('MakeIndexedSlices')
+    indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
+    indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
+    indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
+
+    @fns
+    def before_get_indices(x, y, z):
+        return indexed_slices_get_indices(make_indexed_slices(x, y, z))
+
+    @fns
+    def after_get_indices(x, y, z):
+        return x
+
+    @fns
+    def before_get_values(x, y, z):
+        return indexed_slices_get_values(make_indexed_slices(x, y, z))
+
+    @fns
+    def after_get_values(x, y, z):
+        return y
+
+    @fns
+    def before_get_dense_shape(x, y, z):
+        return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z))
+
+    @fns
+    def after_get_dense_shape(x, y, z):
+        return z
+
+    return fns[tag]
@ -0,0 +1,290 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
@File : test_indexed_slices.py
|
||||||
|
@Author:
|
||||||
|
@Date : 2020-06-08
|
||||||
|
@Desc : test mindspore indexed_slices's operation
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
|
from mindspore.ops.primitive import constexpr
|
||||||
|
from mindspore.ops._grad.grad_base import bprop_getters
|
||||||
|
from mindspore import Tensor, IndexedSlices, context
|
||||||
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from mindspore._checkparam import Rel
|
||||||
|
from mindspore.nn import Optimizer
|
||||||
|
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||||
|
|
||||||
|
reduce_sum = P.ReduceSum()
|
||||||
|
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||||
|
transpose = P.Transpose()
|
||||||
|
shape_op = P.Shape()
|
||||||
|
reshape = P.Reshape()
|
||||||
|
size_op = P.Size()
|
||||||
|
invert_permutation = P.InvertPermutation()
|
||||||
|
logical_and = P.LogicalAnd()
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _generate_shape_index(out_shape, indices_shape, axis):
|
||||||
|
out_rank = len(out_shape)
|
||||||
|
ind_rank = len(indices_shape)
|
||||||
|
if axis < 0:
|
||||||
|
axis += out_rank - ind_rank + 1
|
||||||
|
perm_part1 = tuple(range(axis, axis + ind_rank))
|
||||||
|
index = tuple(range(out_rank))
|
||||||
|
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
|
||||||
|
return perm
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _generate_inverse_index(x_shape, axis):
|
||||||
|
x_rank = len(x_shape)
|
||||||
|
index = tuple(range(x_rank))
|
||||||
|
if axis < 0:
|
||||||
|
axis += x_rank
|
||||||
|
perm = index[1:1 + axis] + (0,) + index[1 + axis:]
|
||||||
|
return perm
|
||||||
|
|
||||||
|
class MySparseGatherV2(P.GatherV2):
|
||||||
|
"""
|
||||||
|
For test
|
||||||
|
"""
|
||||||
|
|
||||||
|
@bprop_getters.register(MySparseGatherV2)
|
||||||
|
def get_bprop_sparse_gather_v2(self):
|
||||||
|
"""Generate bprop for MySparseGatherV2"""
|
||||||
|
|
||||||
|
def bprop(x, indices, axis, out, dout):
|
||||||
|
x_shp = shape_op(x)
|
||||||
|
if axis == 0:
|
||||||
|
indices_size = (size_op(indices),)
|
||||||
|
x_tail_shp = x_shp[1:]
|
||||||
|
values_shape = indices_size + x_tail_shp
|
||||||
|
values = reshape(dout, values_shape)
|
||||||
|
indices = reshape(indices, indices_size)
|
||||||
|
return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
|
||||||
|
if F.rank(dout) == 0:
|
||||||
|
dout = P.ExpandDims()(dout, -1)
|
||||||
|
if F.rank(indices) == 0:
|
||||||
|
indices = P.ExpandDims()(indices, -1)
|
||||||
|
out_shp = shape_op(dout)
|
||||||
|
ind_shp = shape_op(indices)
|
||||||
|
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
|
||||||
|
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
|
||||||
|
values_transpose = transpose(dout, perm_1)
|
||||||
|
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
|
||||||
|
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
|
||||||
|
perm_2 = _generate_inverse_index(x_shp, axis)
|
||||||
|
params_grad = transpose(params_grad, perm_2)
|
||||||
|
return params_grad, zeros_like(indices), zeros_like(axis)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
|
||||||
|
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
|
||||||
|
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||||
|
if gradient.is_indexed_slices():
|
||||||
|
return gradient.values()
|
||||||
|
op_mul = P.Mul()
|
||||||
|
op_square = P.Square()
|
||||||
|
op_sqrt = P.Sqrt()
|
||||||
|
op_cast = P.Cast()
|
||||||
|
op_reshape = P.Reshape()
|
||||||
|
op_shape = P.Shape()
|
||||||
|
|
||||||
|
param_fp32 = op_cast(param, mstype.float32)
|
||||||
|
m_fp32 = op_cast(m, mstype.float32)
|
||||||
|
v_fp32 = op_cast(v, mstype.float32)
|
||||||
|
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||||
|
|
||||||
|
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
|
||||||
|
|
||||||
|
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||||
|
- beta2, op_square(gradient_fp32))
|
||||||
|
|
||||||
|
update = next_m / (op_sqrt(next_v) + eps)
|
||||||
|
if decay_flag:
|
||||||
|
update = update + op_mul(weight_decay_tensor, param_fp32)
|
||||||
|
|
||||||
|
update_with_lr = op_mul(lr, update)
|
||||||
|
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||||
|
|
||||||
|
next_v = F.depend(next_v, F.assign(param, next_param))
|
||||||
|
next_v = F.depend(next_v, F.assign(m, next_m))
|
||||||
|
next_v = F.depend(next_v, F.assign(v, next_v))
|
||||||
|
return next_v
|
||||||
|
|
||||||
|
|
||||||
|
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||||
|
"""Check the type of inputs."""
|
||||||
|
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||||
|
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||||
|
validator.check_value_type("eps", eps, [float], prim_name)
|
||||||
|
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||||
|
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||||
|
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||||
|
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||||
|
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||||
|
|
||||||
|
|
||||||
|
class AdamWeightDecaySparse(Optimizer):
|
||||||
|
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
|
||||||
|
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||||
|
super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
|
||||||
|
if self.is_group:
|
||||||
|
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||||
|
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||||
|
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||||
|
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||||
|
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||||
|
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
|
||||||
|
|
||||||
|
self.params = self.parameters
|
||||||
|
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||||
|
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||||
|
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||||
|
self.map = C.Map()
|
||||||
|
|
||||||
|
def construct(self, gradients):
|
||||||
|
lr = self.get_lr()
|
||||||
|
updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
|
||||||
|
self.weight_decay_tensor),
|
||||||
|
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||||
|
return updated_velocity
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexed_slices_make_indexed_slices():
|
||||||
|
class MakeIndexedSlices(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MakeIndexedSlices, self).__init__()
|
||||||
|
self.dense_shape = (3, 4)
|
||||||
|
def construct(self, indices, values):
|
||||||
|
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
||||||
|
return ret[0].is_indexed_slices()
|
||||||
|
indices = Tensor([[0, 0], [1, 2]])
|
||||||
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
MakeIndexedSlices()(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexed_slices_attr():
|
||||||
|
class IndexedSlicesGetAttr(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(IndexedSlicesGetAttr, self).__init__()
|
||||||
|
self.dense_shape = (3, 4)
|
||||||
|
def construct(self, indices, values):
|
||||||
|
x = IndexedSlices(indices, values, self.dense_shape)
|
||||||
|
return x.values(), x.indices(), x.dense_shape()
|
||||||
|
indices = Tensor([[0, 0], [1, 2]])
|
||||||
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
IndexedSlicesGetAttr()(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexed_slices_sparse_gatherv2_grad_all():
    grad_all = C.GradOperation('get_all', get_all=True)
    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)
    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)

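# Why the gradient comes back as an IndexedSlices (an interpretation of the
# test above): gather reads only the rows of `params` named by `indices`, so
# the backward pass produces values for exactly those rows. A dense equivalent
# of such a gradient, for reference:
grad_rows = np.ones([2, 1, 2], dtype=np.float32)       # gradient of the gathered rows
dense_grad = np.zeros([3, 1, 2], dtype=np.float32)
np.add.at(dense_grad, np.array([0, 1]), grad_rows)     # scatter-add handles repeated indices
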
def test_indexed_slices_sparse_gatherv2_grad_with_pram():
    grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
                                    name="params", has_indexed_slices_grad=True)
        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)

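# Note (a reading of this test, not documented behaviour): has_indexed_slices_grad=True
# appears to mark the parameter so its gradient can stay in IndexedSlices form
# rather than being densified, which is what lets the test call values(),
# indices() and dense_shape() on grad[0].
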
def test_indexed_slices_is_indexed_slices():
    class MakeIndexedSlices(nn.Cell):
        def __init__(self):
            super(MakeIndexedSlices, self).__init__()
            self.dense_shape = (3, 4)
        def construct(self, indices, values):
            indexed_slices = IndexedSlices(indices, values, self.dense_shape)
            ret = indexed_slices.is_indexed_slices()
            return ret
    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    MakeIndexedSlices()(indices, values)

def test_indexed_slices_env_get():
    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()
        def construct(self, base, target):
            return base
    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True)
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0
        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)

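# How the pieces compose (a summary of the test above): WithLossCell feeds
# net(inputs) and label through Loss, TrainOneStepCell differentiates the
# result with respect to the trainable parameters, and AdamWeightDecaySparse
# consumes the mixed gradient tuple: an IndexedSlices for w1 (gathered through
# MySparseGatherV2) and a dense tensor for w2.
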
@@ -155,7 +155,7 @@ def test_AdamWeightDecaySparse():
         def __init__(self):
             super(NetWithSparseGatherV2, self).__init__()
             self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
-            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2", sparse_grad="sparse_key_w2")
+            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
             self.gatherv2 = P.SparseGatherV2()
             self.axis = 0
         def construct(self, indices):