forked from mindspore-Ecosystem/mindspore
Add IndexedSlices
This commit is contained in:
parent
d454daec1b
commit
d6635bbbe2
|
@ -17,6 +17,7 @@
|
|||
"""Resources for ast tree parse."""
|
||||
import ast
|
||||
import math
|
||||
from mindspore import IndexedSlices
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
from . import standard_method as M
|
||||
|
@ -135,4 +136,7 @@ convert_object_map = {
|
|||
math.sin: NO_IMPLEMENT,
|
||||
math.cos: NO_IMPLEMENT,
|
||||
math.tan: NO_IMPLEMENT,
|
||||
|
||||
# user defined
|
||||
IndexedSlices: F.make_indexed_slices,
|
||||
}
|
||||
|
|
|
@ -120,6 +120,10 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
|
|||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
|
||||
}
|
||||
}
|
||||
} else if (type->isa<IndexedSlicesType>()) {
|
||||
// Do Nothing
|
||||
} else if (type->isa<UndeterminedType>()) {
|
||||
// Do Nothing
|
||||
} else if (type->isa<Tuple>()) {
|
||||
TuplePtr tuple_type = dyn_cast<Tuple>(type);
|
||||
type_proto->set_data_type(irpb::DT_TUPLE);
|
||||
|
|
|
@ -94,6 +94,48 @@ bool Slice::operator==(const Type &other) const {
|
|||
|
||||
std::string Slice::DumpText() const { return ToString(); }
|
||||
|
||||
TypePtr UndeterminedType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<UndeterminedType>();
|
||||
}
|
||||
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string UndeterminedType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "Undetermined";
|
||||
}
|
||||
return "Undetermined[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string UndeterminedType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "Undetermined";
|
||||
}
|
||||
return "Undetermined[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string UndeterminedType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "Undetermined";
|
||||
}
|
||||
return "Undetermined[" + element_type_->DumpText() + "]";
|
||||
}
|
||||
|
||||
bool UndeterminedType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
TypePtr TensorType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
|
@ -137,6 +179,48 @@ bool TensorType::operator==(const Type &other) const {
|
|||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
TypePtr IndexedSlicesType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<IndexedSlicesType>();
|
||||
}
|
||||
return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string IndexedSlicesType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "IndexedSlices";
|
||||
}
|
||||
return "IndexedSlices[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string IndexedSlicesType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "IndexedSlices";
|
||||
}
|
||||
return "IndexedSlices[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string IndexedSlicesType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "IndexedSlices";
|
||||
}
|
||||
return "IndexedSlices[" + element_type_->DumpText() + "]";
|
||||
}
|
||||
|
||||
bool IndexedSlicesType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
Function::Function() : Object(kObjectTypeFunction) {
|
||||
args_ = std::vector<TypePtr>();
|
||||
retval_ = nullptr;
|
||||
|
|
|
@ -108,10 +108,34 @@ class Slice : public Object {
|
|||
};
|
||||
using SlicePtr = std::shared_ptr<Slice>;
|
||||
|
||||
class UndeterminedType : public Object {
|
||||
public:
|
||||
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
|
||||
explicit UndeterminedType(const TypePtr &ele)
|
||||
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
|
||||
~UndeterminedType() override = default;
|
||||
MS_DECLARE_PARENT(UndeterminedType, Object)
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
|
||||
const TypePtr element() const { return element_type_; }
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override;
|
||||
std::string DumpText() const override;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
protected:
|
||||
TypePtr element_type_;
|
||||
};
|
||||
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
|
||||
|
||||
class TensorType : public Object {
|
||||
public:
|
||||
TensorType() : Object(kObjectTypeTensorType) {}
|
||||
explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
|
||||
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
|
||||
explicit TensorType(const TypePtr &ele)
|
||||
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
~TensorType() override = default;
|
||||
MS_DECLARE_PARENT(TensorType, Object)
|
||||
|
||||
|
@ -130,6 +154,29 @@ class TensorType : public Object {
|
|||
};
|
||||
using TensorTypePtr = std::shared_ptr<TensorType>;
|
||||
|
||||
class IndexedSlicesType : public Object {
|
||||
public:
|
||||
IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
|
||||
explicit IndexedSlicesType(const TypePtr &ele)
|
||||
: Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
~IndexedSlicesType() override = default;
|
||||
MS_DECLARE_PARENT(IndexedSlicesType, Object)
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
|
||||
const TypePtr element() const { return element_type_; }
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override;
|
||||
std::string DumpText() const override;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
private:
|
||||
TypePtr element_type_;
|
||||
};
|
||||
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
|
||||
|
||||
class Function : public Object {
|
||||
public:
|
||||
Function();
|
||||
|
@ -255,6 +302,8 @@ TypePtr StringToType(const std::string &type_name);
|
|||
// Judge whether x is predicate or is a subclass of predicate.
|
||||
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
|
||||
|
||||
bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type);
|
||||
|
||||
// Whether t1 is identity or a subclass of t2.
|
||||
bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
|
||||
|
||||
|
|
|
@ -115,6 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) {
|
|||
return "kObjectTypeKeyword";
|
||||
case kObjectTypeTensorType:
|
||||
return "kObjectTypeTensorType";
|
||||
case kObjectTypeIndexedSlicesType:
|
||||
return "kObjectTypeIndexedSlicesType";
|
||||
case kObjectTypeUndeterminedType:
|
||||
return "kObjectTypeUndeterminedType";
|
||||
case kObjectTypeDictionary:
|
||||
return "kObjectTypeDictionary";
|
||||
case kObjectTypeClass:
|
||||
|
|
|
@ -67,6 +67,7 @@ class Type : public Value {
|
|||
virtual bool equal(const TypePtr other) const { return *this == *other; }
|
||||
|
||||
virtual TypeId object_type() const { return kTypeUnknown; }
|
||||
virtual TypeId parent_type() const { return kTypeUnknown; }
|
||||
virtual TypeId number_type() const { return kTypeUnknown; }
|
||||
virtual TypePtr DeepCopy() const = 0;
|
||||
virtual TypePtr Clone() const { return DeepCopy(); }
|
||||
|
@ -97,13 +98,16 @@ using TypePtrList = std::vector<TypePtr>;
|
|||
//
|
||||
class Object : public Type {
|
||||
public:
|
||||
Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject) {}
|
||||
Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {}
|
||||
explicit Object(const TypeId object_type, bool is_generic = true)
|
||||
: Type(kMetaTypeObject, is_generic), object_type_(object_type) {}
|
||||
: Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {}
|
||||
explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true)
|
||||
: Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {}
|
||||
~Object() override = default;
|
||||
MS_DECLARE_PARENT(Object, Type)
|
||||
|
||||
TypeId object_type() const override { return object_type_; }
|
||||
TypeId parent_type() const override { return parent_type_; }
|
||||
TypeId type_id() const override { return object_type_; }
|
||||
TypeId generic_type_id() const override { return kMetaTypeObject; }
|
||||
bool equal(const TypePtr other) const override;
|
||||
|
@ -114,6 +118,7 @@ class Object : public Type {
|
|||
|
||||
private:
|
||||
const TypeId object_type_;
|
||||
const TypeId parent_type_;
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
|
||||
|
|
|
@ -50,6 +50,8 @@ enum TypeId : int {
|
|||
kObjectTypeSlice,
|
||||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeIndexedSlicesType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
kObjectTypeDictionary,
|
||||
kObjectTypeFunction,
|
||||
|
|
|
@ -192,6 +192,40 @@ TypePtr TensorStrToType(const std::string &type_name) {
|
|||
return type;
|
||||
}
|
||||
|
||||
TypePtr IndexedSlicesStrToType(const std::string &type_name) {
|
||||
if (type_name == "IndexedSlices") {
|
||||
return std::make_shared<IndexedSlicesType>();
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<IndexedSlicesType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr UndeterminedStrToType(const std::string &type_name) {
|
||||
if (type_name == "Undetermined") {
|
||||
return std::make_shared<UndeterminedType>();
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<UndeterminedType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr ListStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "List") {
|
||||
|
@ -313,6 +347,10 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
type = StringToNumberType<Float>(type_name, "Float");
|
||||
} else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
|
||||
type = TensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
|
||||
type = UndeterminedStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
|
||||
type = IndexedSlicesStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||
type = ListStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||
|
@ -340,6 +378,20 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
return type;
|
||||
}
|
||||
|
||||
bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) {
|
||||
if (x == nullptr || base_type == nullptr) {
|
||||
MS_LOG(ERROR) << "Type is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
}
|
||||
if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
|
||||
if (x == nullptr || base_type == nullptr) {
|
||||
MS_LOG(ERROR) << "Type is nullptr.";
|
||||
|
@ -481,6 +533,10 @@ REGISTER_PYBIND_DEFINE(
|
|||
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
|
||||
.def(py::init());
|
||||
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
||||
.def(py::init());
|
||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
|
||||
|
@ -501,6 +557,8 @@ const TypePtr kTypeExternal = std::make_shared<External>();
|
|||
const TypePtr kTypeEnv = std::make_shared<EnvType>();
|
||||
const TypePtr kTypeType = std::make_shared<TypeType>();
|
||||
const TypePtr kTensorType = std::make_shared<TensorType>();
|
||||
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
|
||||
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
|
||||
const TypePtr kString = std::make_shared<String>();
|
||||
const TypePtr kList = std::make_shared<List>();
|
||||
const TypePtr kTuple = std::make_shared<Tuple>();
|
||||
|
|
|
@ -93,15 +93,17 @@ static TypePtr UnwrapRef(const TypePtr &type) {
|
|||
}
|
||||
return type;
|
||||
}
|
||||
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
||||
bool find_fn = false;
|
||||
py::function py_fn;
|
||||
|
||||
// Return Exact match if exists, else return non ambiguous sub class match
|
||||
// Return py::none() if matching is ambiguous
|
||||
const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
|
||||
// Exact match
|
||||
for (auto &item : fn_cache_py_) {
|
||||
TypePtrList sign = item.first;
|
||||
if (sign.size() != types.size()) {
|
||||
continue;
|
||||
}
|
||||
bool match = true;
|
||||
auto match = true;
|
||||
for (size_t i = 0; i < sign.size(); ++i) {
|
||||
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
|
||||
match = false;
|
||||
|
@ -111,13 +113,45 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
|||
if (!match) {
|
||||
continue;
|
||||
}
|
||||
find_fn = true;
|
||||
py_fn = item.second;
|
||||
break;
|
||||
return item.second;
|
||||
}
|
||||
// Try best match
|
||||
py::function py_fn_subclass;
|
||||
size_t subclass_match_cnt = 0;
|
||||
for (auto &item : fn_cache_py_) {
|
||||
TypePtrList sign = item.first;
|
||||
if (sign.size() != types.size()) {
|
||||
continue;
|
||||
}
|
||||
auto match = true;
|
||||
for (size_t i = 0; i < sign.size(); ++i) {
|
||||
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
|
||||
!IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) {
|
||||
continue;
|
||||
}
|
||||
py_fn_subclass = item.second;
|
||||
subclass_match_cnt++;
|
||||
}
|
||||
if (subclass_match_cnt > 1) {
|
||||
MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
|
||||
}
|
||||
if (subclass_match_cnt == 1) {
|
||||
MS_LOG(DEBUG) << "Found one subclass match";
|
||||
return py_fn_subclass;
|
||||
}
|
||||
return py::none();
|
||||
}
|
||||
|
||||
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
||||
auto py_fn = SignMatch(types);
|
||||
std::ostringstream buffer;
|
||||
buffer << types;
|
||||
if (find_fn) {
|
||||
if (py_fn != py::none()) {
|
||||
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
|
||||
|
|
|
@ -54,6 +54,7 @@ class MultitypeFuncGraph : public MetaFuncGraph {
|
|||
}
|
||||
|
||||
private:
|
||||
const py::function SignMatch(const TypePtrList &types);
|
||||
std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
|
||||
std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
|
||||
};
|
||||
|
|
|
@ -277,5 +277,12 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
|
|||
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
||||
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
|
||||
const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
|
||||
|
||||
// IndexedSlices
|
||||
const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
||||
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -287,6 +287,13 @@ extern const PrimitivePtr kPrimMirror;
|
|||
extern const PrimitivePtr kPrimVirtualDiv;
|
||||
extern const PrimitivePtr kPrimVirtualDataset;
|
||||
|
||||
// IndexedSlices
|
||||
extern const PrimitivePtr kPrimMakeIndexedSlices;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
|
||||
extern const PrimitivePtr kPrimIsIndexedSlices;
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "pipeline/static_analysis/prim.h"
|
||||
#include "pipeline/static_analysis/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -173,6 +174,13 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
return std::make_shared<AbstractTuple>(sparse_list);
|
||||
}
|
||||
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
|
||||
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
|
||||
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
|
||||
}
|
||||
if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
|
||||
return dflt;
|
||||
}
|
||||
|
@ -236,6 +244,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
}
|
||||
auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||
ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -437,5 +446,72 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
|
||||
|
||||
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||
auto shp = dense_shape_value->value();
|
||||
std::vector<int> dense_shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
|
||||
[](const ValuePtr &e) -> int {
|
||||
auto elem = GetValue<int>(e);
|
||||
return elem;
|
||||
});
|
||||
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
||||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(indexed_slices->values());
|
||||
return indexed_slices->values();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(indexed_slices->indices());
|
||||
return indexed_slices->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
|
||||
return indexed_slices->dense_shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
bool ret = false;
|
||||
if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
|
||||
ret = true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,6 +36,7 @@ using mindspore::abstract::AbstractJTagged;
|
|||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractUndetermined;
|
||||
|
||||
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
|
||||
if (t == nullptr) {
|
||||
|
@ -78,7 +79,7 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(cons);
|
||||
|
||||
auto dt = data->abstract();
|
||||
if (dt == nullptr) {
|
||||
if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include "optimizer/irpass/tile_eliminate.h"
|
||||
#include "optimizer/irpass/transpose_eliminate.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "optimizer/irpass/indexed_slices_eliminate.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -153,6 +154,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
// Mark interface fusion
|
||||
mark_interface_fusion_ =
|
||||
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
|
||||
|
||||
// IndexedSlices Eliminate
|
||||
indexed_slices_eliminate_ = MakeSubstitution(
|
||||
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
|
||||
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
|
||||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
|
|
|
@ -104,6 +104,9 @@ class OptimizeIRPassLib {
|
|||
|
||||
// Fusion
|
||||
SubstitutionPtr mark_interface_fusion_;
|
||||
|
||||
// IndexedSlices Eliminate
|
||||
SubstitutionPtr indexed_slices_eliminate_;
|
||||
};
|
||||
|
||||
// the collection of irpass for resolve action
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
|
||||
// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
|
||||
// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
|
||||
class IndexedSlicesEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(1);
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(2);
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(3);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
|
||||
tuple_ = cnode;
|
||||
is_match_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
tuple_ = nullptr;
|
||||
is_match_ = false;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_match_{false};
|
||||
CNodePtr tuple_{nullptr};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
|
|
@ -232,6 +232,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
auto sparse_grad =
|
||||
py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
|
||||
ptr->set_sparse_grad(sparse_grad);
|
||||
auto has_indexed_slices_grad =
|
||||
py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad"));
|
||||
ptr->set_has_indexed_slices_grad(has_indexed_slices_grad);
|
||||
|
||||
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
|
||||
args_spec.push_back(ptr);
|
||||
|
|
|
@ -154,7 +154,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
|
||||
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
|
||||
"Set the GraphKernel switch to on or off.")
|
||||
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
|
||||
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
|
||||
.def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
|
||||
.def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -156,6 +156,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.replace_refkey_by_param_,
|
||||
irpass.make_ref_eliminate_,
|
||||
irpass.get_ref_param_eliminate_,
|
||||
irpass.indexed_slices_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"b_1", b_1},
|
||||
|
|
|
@ -33,148 +33,157 @@ namespace mindspore {
|
|||
namespace pipeline {
|
||||
|
||||
MethodMap &GetMethodMap() {
|
||||
static MethodMap method_map = {{kObjectTypeString,
|
||||
{
|
||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
||||
}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
{"__or__", prim::kPrimBoolOr}, // P.bool_or
|
||||
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
|
||||
{"__ne__", std::string("bool_ne")}, // C.bool_ne
|
||||
{"__bool__", prim::kPrimIdentity} // P.identity
|
||||
}},
|
||||
{kNumberTypeInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
|
||||
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
|
||||
}},
|
||||
{kNumberTypeUInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kNumberTypeFloat,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
|
||||
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
|
||||
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("float_bool")}, // C.float_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kObjectTypeTuple,
|
||||
{
|
||||
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
|
||||
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
|
||||
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
|
||||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
{"__len__", prim::kPrimListLen}, // P.list_len,
|
||||
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
|
||||
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
|
||||
{"__ms_next__", std::string("list_next")}, // C.list_next
|
||||
{"append", std::string("list_append")}, // C.list_next
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
static MethodMap method_map = {
|
||||
{kObjectTypeString,
|
||||
{
|
||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
||||
}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
{"__or__", prim::kPrimBoolOr}, // P.bool_or
|
||||
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
|
||||
{"__ne__", std::string("bool_ne")}, // C.bool_ne
|
||||
{"__bool__", prim::kPrimIdentity} // P.identity
|
||||
}},
|
||||
{kNumberTypeInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
|
||||
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
|
||||
}},
|
||||
{kNumberTypeUInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kNumberTypeFloat,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
|
||||
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
|
||||
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("float_bool")}, // C.float_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kObjectTypeTuple,
|
||||
{
|
||||
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
|
||||
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
|
||||
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
|
||||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
{"__len__", prim::kPrimListLen}, // P.list_len,
|
||||
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
|
||||
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
|
||||
{"__ms_next__", std::string("list_next")}, // C.list_next
|
||||
{"append", std::string("list_append")}, // C.list_next
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
|
||||
}},
|
||||
{kObjectTypeIndexedSlicesType,
|
||||
{
|
||||
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
|
||||
{"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values
|
||||
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
|
||||
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
return method_map;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,10 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
|
|||
if (tid() != other.tid()) {
|
||||
return false;
|
||||
}
|
||||
if (BuildType()->type_id() == kObjectTypeUndeterminedType &&
|
||||
other.BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
return true;
|
||||
}
|
||||
if (value_ == nullptr || other.value_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
|
||||
<< this->ToString() << ", other: " << other.ToString();
|
||||
|
@ -65,7 +69,7 @@ std::string AbstractBase::ToString() const {
|
|||
MS_EXCEPTION_IF_NULL(shape_);
|
||||
buffer << type_name() << "("
|
||||
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
|
||||
<< " sparse_grad: " << sparse_grad_ << ")";
|
||||
<< " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
@ -76,6 +80,7 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
|||
if (*this == *other) {
|
||||
auto ret = shared_from_base<AbstractBase>();
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
}
|
||||
auto value_self = GetValueTrack();
|
||||
|
@ -85,10 +90,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
|||
if (res_value == value_self) {
|
||||
auto ret = shared_from_base<AbstractBase>();
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
}
|
||||
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -409,6 +416,14 @@ std::size_t AbstractSlice::hash() const {
|
|||
return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
|
||||
}
|
||||
|
||||
ShapePtr AbstractUndetermined::shape() const {
|
||||
auto shp = dyn_cast<Shape>(GetShapeTrack());
|
||||
if (shp == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Tensor should have a shape.";
|
||||
}
|
||||
return shp;
|
||||
}
|
||||
|
||||
TypePtr AbstractTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element_);
|
||||
TypePtr element_type = element_->BuildType();
|
||||
|
@ -425,6 +440,13 @@ BaseShapePtr AbstractTensor::BuildShape() const {
|
|||
}
|
||||
|
||||
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
||||
if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
auto other_tensor = dyn_cast<AbstractUndetermined>(other);
|
||||
auto element = element_->Join(other_tensor->element());
|
||||
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
|
||||
auto ret = std::make_shared<AbstractUndetermined>(element, shape);
|
||||
return ret;
|
||||
}
|
||||
auto other_tensor = dyn_cast<AbstractTensor>(other);
|
||||
if (other_tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
|
@ -433,6 +455,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
|||
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
|
||||
auto ret = std::make_shared<AbstractTensor>(element, shape);
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -474,6 +497,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
|
|||
clone->set_shape(shp->Clone());
|
||||
clone->set_value(GetValueTrack());
|
||||
clone->set_sparse_grad(sparse_grad());
|
||||
clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
@ -484,6 +508,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
|
|||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_sparse_grad(sparse_grad());
|
||||
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
|
@ -495,17 +520,10 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
|
|||
broaden->set_shape(shp);
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_sparse_grad(sparse_grad());
|
||||
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
ShapePtr AbstractTensor::shape() const {
|
||||
auto shp = dyn_cast<Shape>(GetShapeTrack());
|
||||
if (shp == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Tensor should have a shape.";
|
||||
}
|
||||
return shp;
|
||||
}
|
||||
|
||||
std::string AbstractTensor::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
BaseShapePtr shape_track = GetShapeTrack();
|
||||
|
@ -516,7 +534,7 @@ std::string AbstractTensor::ToString() const {
|
|||
buffer << type_name() << "("
|
||||
<< "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
|
||||
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
|
||||
<< ")";
|
||||
<< " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
@ -1019,5 +1037,64 @@ std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &arg
|
|||
bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
|
||||
return AbstractBasePtrListDeepEqual(lhs, rhs);
|
||||
}
|
||||
|
||||
// IndexedSlices
|
||||
TypePtr AbstractIndexedSlices::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
TypePtr element_type = element()->BuildType();
|
||||
return std::make_shared<IndexedSlicesType>(element_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractIndexedSlices::Clone() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
|
||||
ShapePtr shp = shape();
|
||||
clone->set_shape(shp->Clone());
|
||||
clone->set_value(GetValueTrack());
|
||||
clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||
clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||
clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractIndexedSlices::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
|
||||
auto shp = shape();
|
||||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
|
||||
auto shp = shape()->Clone();
|
||||
shp->Broaden();
|
||||
broaden->set_shape(shp);
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
std::string AbstractIndexedSlices::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
BaseShapePtr shape_track = GetShapeTrack();
|
||||
MS_EXCEPTION_IF_NULL(shape_track);
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto value_track = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
buffer << type_name() << "("
|
||||
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
|
||||
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
|
||||
<< ", indices: " << indices_->ToString() << ", values" << values_->ToString()
|
||||
<< ", dense_shape: " << dense_shape_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,7 +44,7 @@ class AbstractBase : public Base {
|
|||
public:
|
||||
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
|
||||
const BaseShapePtr &shape = kNoShape)
|
||||
: value_(value), type_(type), shape_(shape), sparse_grad_("") {}
|
||||
: value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
|
||||
~AbstractBase() override = default;
|
||||
MS_DECLARE_PARENT(AbstractBase, Base)
|
||||
|
||||
|
@ -54,12 +54,16 @@ class AbstractBase : public Base {
|
|||
virtual bool operator==(const AbstractBase &other) const;
|
||||
void set_value(const ValuePtr &value) { value_ = value; }
|
||||
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
|
||||
void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
|
||||
has_indexed_slices_grad_ = has_indexed_slices_grad;
|
||||
}
|
||||
void set_type(const TypePtr &type) { type_ = type; }
|
||||
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
|
||||
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
|
||||
const std::string &value_desc() const { return value_desc_; }
|
||||
ValuePtr GetValueTrack() const { return value_; }
|
||||
const std::string &sparse_grad() const { return sparse_grad_; }
|
||||
const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
|
||||
TypePtr GetTypeTrack() const { return type_; }
|
||||
BaseShapePtr GetShapeTrack() const { return shape_; }
|
||||
|
||||
|
@ -88,6 +92,7 @@ class AbstractBase : public Base {
|
|||
BaseShapePtr shape_;
|
||||
std::string value_desc_; // store initial value description for error report
|
||||
std::string sparse_grad_;
|
||||
bool has_indexed_slices_grad_;
|
||||
};
|
||||
|
||||
class AbstractScalar : public AbstractBase {
|
||||
|
@ -231,35 +236,49 @@ class AbstractKeywordArg : public AbstractBase {
|
|||
};
|
||||
using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
|
||||
|
||||
class AbstractTensor : public AbstractBase {
|
||||
class AbstractUndetermined : public AbstractBase {
|
||||
public:
|
||||
// shape and type are all unknown
|
||||
AbstractUndetermined() : AbstractBase(kAnyValue) {}
|
||||
// only element_ and value, shape track are valid member, type track are unknown.
|
||||
explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractBase(kAnyValue), element_(element) {
|
||||
if (element == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "element is nullptr";
|
||||
}
|
||||
if (element->isa<AbstractTensor>()) {
|
||||
if (element->isa<AbstractUndetermined>()) {
|
||||
MS_LOG(EXCEPTION) << "element type error";
|
||||
}
|
||||
set_shape(shape);
|
||||
}
|
||||
AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
AbstractUndetermined(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
|
||||
if (element_type == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "element_type is nullptr";
|
||||
}
|
||||
set_shape(std::make_shared<Shape>(shape));
|
||||
}
|
||||
explicit AbstractTensor(const tensor::TensorPtr &tensor)
|
||||
: AbstractBase(tensor), element_(std::make_shared<AbstractScalar>(kAnyValue, tensor->Dtype())) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "tensor is nullptr";
|
||||
}
|
||||
set_shape(std::make_shared<Shape>(tensor->shape()));
|
||||
}
|
||||
~AbstractUndetermined() override = default;
|
||||
MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
|
||||
TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }
|
||||
const AbstractBasePtr element() const { return element_; }
|
||||
ShapePtr shape() const;
|
||||
|
||||
protected:
|
||||
AbstractBasePtr element_;
|
||||
};
|
||||
|
||||
class AbstractTensor : public AbstractUndetermined {
|
||||
public:
|
||||
// only element_ and value, shape track are valid member, type track are unknown.
|
||||
explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractUndetermined(element, shape) {}
|
||||
AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
: AbstractUndetermined(element_type, shape) {}
|
||||
explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
|
||||
~AbstractTensor() override = default;
|
||||
MS_DECLARE_PARENT(AbstractTensor, AbstractBase)
|
||||
MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)
|
||||
|
||||
TypePtr BuildType() const override;
|
||||
BaseShapePtr BuildShape() const override;
|
||||
|
@ -271,9 +290,7 @@ class AbstractTensor : public AbstractBase {
|
|||
bool operator==(const AbstractTensor &other) const;
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
|
||||
ShapePtr shape() const;
|
||||
std::string ToString() const override;
|
||||
const AbstractBasePtr element() const { return element_; }
|
||||
std::size_t hash() const override {
|
||||
auto value = GetValueTrack();
|
||||
auto hash_sum = hash_combine(tid(), element_->hash());
|
||||
|
@ -285,9 +302,6 @@ class AbstractTensor : public AbstractBase {
|
|||
}
|
||||
return hash_sum;
|
||||
}
|
||||
|
||||
private:
|
||||
AbstractBasePtr element_;
|
||||
};
|
||||
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
|
||||
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
|
||||
|
@ -585,6 +599,35 @@ struct AbstractBasePtrListEqual {
|
|||
|
||||
std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
|
||||
bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
|
||||
|
||||
// IndexedSlices
|
||||
class AbstractIndexedSlices : public AbstractUndetermined {
|
||||
public:
|
||||
explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractUndetermined(element, shape) {}
|
||||
AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
: AbstractUndetermined(element_type, shape) {}
|
||||
~AbstractIndexedSlices() override = default;
|
||||
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
|
||||
|
||||
const AbstractTensorPtr indices() const { return indices_; }
|
||||
const AbstractTensorPtr values() const { return values_; }
|
||||
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
||||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
AbstractTensorPtr indices_;
|
||||
AbstractTensorPtr values_;
|
||||
AbstractTuplePtr dense_shape_;
|
||||
};
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_
|
||||
|
|
|
@ -58,6 +58,20 @@ class Evaluator : public Base {
|
|||
return args_spec_list;
|
||||
}
|
||||
|
||||
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
|
||||
auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
|
||||
if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (is_abstract) {
|
||||
MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
|
||||
return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string ToString() const override { return identifier_; }
|
||||
|
||||
virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
|
||||
|
|
|
@ -66,6 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
|
|||
ABSTRACT_REPORT_NAME_TRAITS(Type)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(Class)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "pipeline/parse/resolve.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "pipeline/static_analysis/param_validator.h"
|
||||
#include "common/utils.h"
|
||||
|
@ -132,6 +133,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, {InferImplDebug, true}},
|
||||
// IndexedSlices
|
||||
{prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}},
|
||||
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
|
||||
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
|
||||
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
|
||||
{prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
@ -139,6 +146,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
using mindspore::parse::PyObjectWrapper;
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
}
|
||||
prim_->BeginRecordAddAttr();
|
||||
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
|
@ -485,6 +502,16 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
|||
} // end anonymous namespace
|
||||
|
||||
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
|
||||
|
||||
const auto &iter = cache_->find(args);
|
||||
|
@ -512,6 +539,16 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
|
|||
}
|
||||
|
||||
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
}
|
||||
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
|
||||
if (nargs_ != args.size()) {
|
||||
MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
|
||||
|
@ -871,6 +908,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
auto ref_value = ref_abs->ref();
|
||||
MS_EXCEPTION_IF_NULL(ref_value);
|
||||
ret->set_sparse_grad(ref_value->sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
|
||||
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
|
@ -886,6 +924,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
|
||||
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
|
||||
abs_scalar->set_sparse_grad(x->sparse_grad());
|
||||
abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
|
||||
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
};
|
||||
|
@ -897,6 +936,16 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
|||
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag) {
|
||||
auto ret_abstract = AbstractEval(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
}
|
||||
// Inputs: data, item
|
||||
if (args_spec_list.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
||||
|
|
|
@ -350,6 +350,17 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
|
|||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
|
||||
|
||||
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -228,6 +228,10 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
|
||||
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
|
||||
}
|
||||
if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
|
||||
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
|
||||
if (func == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString()
|
||||
|
|
|
@ -32,6 +32,7 @@ using mindspore::abstract::AbstractBase;
|
|||
using mindspore::abstract::AbstractClass;
|
||||
using mindspore::abstract::AbstractError;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractIndexedSlices;
|
||||
using mindspore::abstract::AbstractJTagged;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
|
@ -93,7 +94,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
|
||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
|
||||
ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
max_device_memory_ = kDefaultMaxDeviceMemory;
|
||||
print_file_path_ = "";
|
||||
enable_graph_kernel_ = false;
|
||||
enable_sparse_flag_ = false;
|
||||
}
|
||||
|
||||
std::shared_ptr<MsContext> MsContext::GetInstance() {
|
||||
|
|
|
@ -161,6 +161,9 @@ class MsContext {
|
|||
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
|
||||
bool enable_graph_kernel() const { return enable_graph_kernel_; }
|
||||
|
||||
bool enable_sparse_flag() const { return enable_sparse_flag_; }
|
||||
void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
|
||||
|
||||
private:
|
||||
MsContext(const std::string &backend_policy, const std::string &target);
|
||||
void GetGeOptions(std::map<std::string, std::string> *ge_options) const;
|
||||
|
@ -204,6 +207,7 @@ class MsContext {
|
|||
float max_device_memory_;
|
||||
std::string print_file_path_;
|
||||
bool enable_graph_kernel_;
|
||||
bool enable_sparse_flag_;
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,10 +17,10 @@ from . import dtype
|
|||
from .api import ms_function
|
||||
from .dtype import *
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .tensor import MetaTensor, Tensor
|
||||
from .tensor import MetaTensor, Tensor, IndexedSlices
|
||||
|
||||
__all__ = [
|
||||
"MetaTensor", "Tensor", # tensor
|
||||
"MetaTensor", "Tensor", "IndexedSlices", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype"
|
||||
|
|
|
@ -52,13 +52,16 @@ class Parameter:
|
|||
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
|
||||
broadcast and gradients communication would not be applied on parameters. Default: False.
|
||||
sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
|
||||
has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false.
|
||||
"""
|
||||
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=""):
|
||||
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
|
||||
sparse_grad="", has_indexed_slices_grad=False):
|
||||
self.set_parameter_data(default_input)
|
||||
self.name = name
|
||||
self.requires_grad = requires_grad
|
||||
self.layerwise_parallel = layerwise_parallel
|
||||
self.sparse_grad = sparse_grad
|
||||
self.has_indexed_slices_grad = has_indexed_slices_grad
|
||||
self._is_init = False
|
||||
self._sliced = False
|
||||
self.clone_info = _CloneInfo()
|
||||
|
@ -186,6 +189,17 @@ class Parameter:
|
|||
raise TypeError("`sparse_grad` parameter must be str type")
|
||||
self._sparse_grad = value
|
||||
|
||||
@property
|
||||
def has_indexed_slices_grad(self):
|
||||
"""Return whether the parameter's gradient is indexed_slices."""
|
||||
return self._has_indexed_slices_grad
|
||||
|
||||
@has_indexed_slices_grad.setter
|
||||
def has_indexed_slices_grad(self, value=False):
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
|
||||
self._has_indexed_slices_grad = value
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.default_input
|
||||
|
|
|
@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
|
|||
from . import dtype as mstype
|
||||
from ._register_for_tensor import tensor_operator_registry
|
||||
|
||||
__all__ = ['Tensor', 'MetaTensor']
|
||||
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||
np.float32, np.float64, np.bool_)
|
||||
|
@ -214,3 +214,8 @@ class Tensor(Tensor_):
|
|||
raise TypeError("init_flag must be bool.")
|
||||
self.set_init_flag(value)
|
||||
self._init_flag = value
|
||||
|
||||
|
||||
class IndexedSlices:
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -355,6 +355,14 @@ class _Context:
|
|||
def check_bprop(self, check_bprop_flag):
|
||||
self._context_handle.set_check_bprop_flag(check_bprop_flag)
|
||||
|
||||
@property
|
||||
def enable_sparse(self):
|
||||
return self._context_handle.get_enable_sparse_flag()
|
||||
|
||||
@enable_sparse.setter
|
||||
def enable_sparse(self, enable_sparse_flag):
|
||||
self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
|
||||
|
||||
@property
|
||||
def max_device_memory(self):
|
||||
return self._context_handle.get_max_device_memory()
|
||||
|
@ -510,7 +518,8 @@ def reset_auto_parallel_context():
|
|||
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
|
||||
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
||||
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
||||
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str)
|
||||
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
|
||||
enable_sparse=bool)
|
||||
def set_context(**kwargs):
|
||||
"""
|
||||
Sets context for running environment.
|
||||
|
@ -567,6 +576,7 @@ def set_context(**kwargs):
|
|||
The format is "xxGB". Default: "1024GB".
|
||||
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
|
||||
a file by default, and turn off printing to the screen.
|
||||
enable_sparse (bool): Whether to enable sparse feature. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not an attribute in context.
|
||||
|
|
|
@ -153,6 +153,14 @@ shape_mul = Primitive("shape_mul")
|
|||
# a primitive to compare between tuple.
|
||||
stop_gradient = Primitive("stop_gradient")
|
||||
|
||||
|
||||
make_indexed_slices = Primitive('MakeIndexedSlices')
|
||||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||
is_indexed_slices = Primitive('IsIndexedSlices')
|
||||
|
||||
|
||||
tensor_operator_registry.register('__add__', tensor_add)
|
||||
tensor_operator_registry.register('__sub__', tensor_sub)
|
||||
tensor_operator_registry.register('__mul__', tensor_mul)
|
||||
|
|
|
@ -564,7 +564,7 @@ class SparseGatherV2(GatherV2):
|
|||
>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
|
||||
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
|
||||
>>> axis = 1
|
||||
>>> out = P.GatherV2()(input_params, input_indices, axis)
|
||||
>>> out = P.SparseGatherV2()(input_params, input_indices, axis)
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -603,5 +603,18 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
|
|||
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_indexed_slices) {
|
||||
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices");
|
||||
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices");
|
||||
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values");
|
||||
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values");
|
||||
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape");
|
||||
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape");
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_});
|
||||
ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1130,3 +1130,38 @@ def test_adjust_allreduce_mul_add(tag):
|
|||
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_indexed_slices(tag):
|
||||
""" test_add_zero """
|
||||
fns = FnDict()
|
||||
make_indexed_slices = Primitive('MakeIndexedSlices')
|
||||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||
|
||||
@fns
|
||||
def before_get_indices(x, y, z):
|
||||
return indexed_slices_get_indices(make_indexed_slices(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_indices(x, y, z):
|
||||
return x
|
||||
|
||||
@fns
|
||||
def before_get_values(x, y, z):
|
||||
return indexed_slices_get_values(make_indexed_slices(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_values(x, y, z):
|
||||
return y
|
||||
|
||||
@fns
|
||||
def before_get_dense_shape(x, y, z):
|
||||
return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_dense_shape(x, y, z):
|
||||
return z
|
||||
|
||||
return fns[tag]
|
||||
|
|
|
@ -0,0 +1,290 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
@File : test_indexed_slices.py
|
||||
@Author:
|
||||
@Date : 2020-06-08
|
||||
@Desc : test mindspore indexed_slices's operation
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops._grad.grad_base import bprop_getters
|
||||
from mindspore import Tensor, IndexedSlices, context
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn import Optimizer
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
|
||||
reduce_sum = P.ReduceSum()
|
||||
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||
transpose = P.Transpose()
|
||||
shape_op = P.Shape()
|
||||
reshape = P.Reshape()
|
||||
size_op = P.Size()
|
||||
invert_permutation = P.InvertPermutation()
|
||||
logical_and = P.LogicalAnd()
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||
|
||||
@constexpr
|
||||
def _generate_shape_index(out_shape, indices_shape, axis):
|
||||
out_rank = len(out_shape)
|
||||
ind_rank = len(indices_shape)
|
||||
if axis < 0:
|
||||
axis += out_rank - ind_rank + 1
|
||||
perm_part1 = tuple(range(axis, axis + ind_rank))
|
||||
index = tuple(range(out_rank))
|
||||
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
|
||||
return perm
|
||||
|
||||
@constexpr
|
||||
def _generate_inverse_index(x_shape, axis):
|
||||
x_rank = len(x_shape)
|
||||
index = tuple(range(x_rank))
|
||||
if axis < 0:
|
||||
axis += x_rank
|
||||
perm = index[1:1 + axis] + (0,) + index[1 + axis:]
|
||||
return perm
|
||||
|
||||
class MySparseGatherV2(P.GatherV2):
|
||||
"""
|
||||
For test
|
||||
"""
|
||||
|
||||
@bprop_getters.register(MySparseGatherV2)
|
||||
def get_bprop_sparse_gather_v2(self):
|
||||
"""Generate bprop for MySparseGatherV2"""
|
||||
|
||||
def bprop(x, indices, axis, out, dout):
|
||||
x_shp = shape_op(x)
|
||||
if axis == 0:
|
||||
indices_size = (size_op(indices),)
|
||||
x_tail_shp = x_shp[1:]
|
||||
values_shape = indices_size + x_tail_shp
|
||||
values = reshape(dout, values_shape)
|
||||
indices = reshape(indices, indices_size)
|
||||
return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
|
||||
if F.rank(dout) == 0:
|
||||
dout = P.ExpandDims()(dout, -1)
|
||||
if F.rank(indices) == 0:
|
||||
indices = P.ExpandDims()(indices, -1)
|
||||
out_shp = shape_op(dout)
|
||||
ind_shp = shape_op(indices)
|
||||
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
|
||||
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
|
||||
values_transpose = transpose(dout, perm_1)
|
||||
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
|
||||
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
|
||||
perm_2 = _generate_inverse_index(x_shp, axis)
|
||||
params_grad = transpose(params_grad, perm_2)
|
||||
return params_grad, zeros_like(indices), zeros_like(axis)
|
||||
|
||||
return bprop
|
||||
|
||||
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
|
||||
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
|
||||
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||
if gradient.is_indexed_slices():
|
||||
return gradient.values()
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
op_cast = P.Cast()
|
||||
op_reshape = P.Reshape()
|
||||
op_shape = P.Shape()
|
||||
|
||||
param_fp32 = op_cast(param, mstype.float32)
|
||||
m_fp32 = op_cast(m, mstype.float32)
|
||||
v_fp32 = op_cast(v, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
|
||||
|
||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||
- beta2, op_square(gradient_fp32))
|
||||
|
||||
update = next_m / (op_sqrt(next_v) + eps)
|
||||
if decay_flag:
|
||||
update = update + op_mul(weight_decay_tensor, param_fp32)
|
||||
|
||||
update_with_lr = op_mul(lr, update)
|
||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||
|
||||
next_v = F.depend(next_v, F.assign(param, next_param))
|
||||
next_v = F.depend(next_v, F.assign(m, next_m))
|
||||
next_v = F.depend(next_v, F.assign(v, next_v))
|
||||
return next_v
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
|
||||
|
||||
class AdamWeightDecaySparse(Optimizer):
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
|
||||
if self.is_group:
|
||||
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
|
||||
|
||||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
self.map = C.Map()
|
||||
|
||||
def construct(self, gradients):
|
||||
lr = self.get_lr()
|
||||
updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
return updated_velocity
|
||||
|
||||
|
||||
def test_indexed_slices_make_indexed_slices():
|
||||
class MakeIndexedSlices(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MakeIndexedSlices, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
||||
return ret[0].is_indexed_slices()
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
MakeIndexedSlices()(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_attr():
|
||||
class IndexedSlicesGetAttr(nn.Cell):
|
||||
def __init__(self):
|
||||
super(IndexedSlicesGetAttr, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
x = IndexedSlices(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
IndexedSlicesGetAttr()(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_sparse_gatherv2_grad_all():
|
||||
grad_all = C.GradOperation('get_all', get_all=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
def construct(self, x, y):
|
||||
grad = grad_all(self.network)(x, y)
|
||||
return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()
|
||||
class SparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseGatherV2, self).__init__()
|
||||
self.sparse_gatherv2 = MySparseGatherV2()
|
||||
self.axis = 0
|
||||
def construct(self, params, indices):
|
||||
return self.sparse_gatherv2(params, indices, self.axis)
|
||||
params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
GradWrap(SparseGatherV2())(params, indices)
|
||||
|
||||
|
||||
def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
||||
grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
|
||||
def construct(self, x):
|
||||
weights = self.weights
|
||||
grad = grad_by_list(self.network, weights)(x)
|
||||
x = grad[0]
|
||||
return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()
|
||||
class SparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseGatherV2, self).__init__()
|
||||
self.sparse_gatherv2 = MySparseGatherV2()
|
||||
self.axis = 0
|
||||
self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
|
||||
name="params", has_indexed_slices_grad=True)
|
||||
def construct(self, indices):
|
||||
return self.sparse_gatherv2(self.params, indices, self.axis)
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
network = GradWrap(SparseGatherV2())
|
||||
network(indices)
|
||||
|
||||
|
||||
def test_indexed_slices_is_indexed_slices():
|
||||
class MakeIndexedSlices(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MakeIndexedSlices, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
indexed_slices = IndexedSlices(indices, values, self.dense_shape)
|
||||
ret = indexed_slices.is_indexed_slices()
|
||||
return ret
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
MakeIndexedSlices()(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_env_get():
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Loss, self).__init__()
|
||||
def construct(self, base, target):
|
||||
return base
|
||||
class NetWithSparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True)
|
||||
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
|
||||
self.gatherv2 = MySparseGatherV2()
|
||||
self.axis = 0
|
||||
def construct(self, indices):
|
||||
return self.gatherv2(self.w1, indices, self.axis) * self.w2
|
||||
|
||||
inputs = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
|
||||
net = NetWithSparseGatherV2()
|
||||
net.set_train()
|
||||
loss = Loss()
|
||||
optimizer = AdamWeightDecaySparse(net.trainable_params())
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
train_network(inputs, label)
|
|
@ -155,7 +155,7 @@ def test_AdamWeightDecaySparse():
|
|||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
|
||||
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2", sparse_grad="sparse_key_w2")
|
||||
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
self.axis = 0
|
||||
def construct(self, indices):
|
||||
|
|
Loading…
Reference in New Issue