forked from mindspore-Ecosystem/mindspore
fix cell bprop
This commit is contained in:
parent
76668bef16
commit
34e50e5d6e
|
@ -17,7 +17,7 @@
|
|||
"""Resources for ast tree parse."""
|
||||
import ast
|
||||
import math
|
||||
from mindspore import IndexedSlices, SparseTensor
|
||||
from mindspore import RowTensor, SparseTensor
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
from . import standard_method as M
|
||||
|
@ -140,6 +140,6 @@ convert_object_map = {
|
|||
math.tan: NO_IMPLEMENT,
|
||||
|
||||
# user defined
|
||||
IndexedSlices: F.make_indexed_slices,
|
||||
RowTensor: F.make_row_tensor,
|
||||
SparseTensor: F.make_sparse_tensor,
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
|
|||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
|
||||
}
|
||||
}
|
||||
} else if (type->isa<IndexedSlicesType>()) {
|
||||
} else if (type->isa<RowTensorType>()) {
|
||||
// Do Nothing
|
||||
} else if (type->isa<UndeterminedType>()) {
|
||||
// Do Nothing
|
||||
|
|
|
@ -174,12 +174,11 @@ inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_Virtua
|
|||
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
|
||||
// IndexedSlices
|
||||
inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
|
||||
inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
||||
inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
||||
inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
||||
inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
|
||||
// RowTensor
|
||||
inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
|
||||
inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues");
|
||||
inline const PrimitivePtr kPrimRowTensorGetIndices = std::make_shared<Primitive>("RowTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimRowTensorGetDenseShape = std::make_shared<Primitive>("RowTensorGetDenseShape");
|
||||
|
||||
// SparseTensor
|
||||
inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
|
||||
|
|
|
@ -340,8 +340,8 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
|
|||
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
|
@ -393,41 +393,41 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
|||
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
||||
}
|
||||
}
|
||||
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
||||
auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec);
|
||||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(row_tensor->values());
|
||||
return row_tensor->values();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(row_tensor->indices());
|
||||
return row_tensor->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(indexed_slices->values());
|
||||
return indexed_slices->values();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(indexed_slices->indices());
|
||||
return indexed_slices->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
|
||||
return indexed_slices->dense_shape();
|
||||
auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
|
||||
return row_tensor->dense_shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -32,9 +32,9 @@ namespace opt {
|
|||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractClass;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractIndexedSlices;
|
||||
using mindspore::abstract::AbstractJTagged;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractRowTensor;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSparseTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
|
@ -81,10 +81,10 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
|
|||
return std::make_shared<AbstractTuple>(abstract_list);
|
||||
}
|
||||
|
||||
if (t->isa<AbstractIndexedSlices>()) {
|
||||
auto abs_indexed_slices = dyn_cast<AbstractIndexedSlices>(t);
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_indexed_slices->indices(), abs_indexed_slices->values(),
|
||||
abs_indexed_slices->dense_shape()};
|
||||
if (t->isa<AbstractRowTensor>()) {
|
||||
auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t);
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(),
|
||||
abs_row_tensor->dense_shape()};
|
||||
return std::make_shared<AbstractTuple>(abstract_list);
|
||||
}
|
||||
|
||||
|
@ -455,16 +455,16 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager
|
|||
} else if (IsValueNode<ValueList>(node)) {
|
||||
new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimMakeIndexedSlices)) {
|
||||
IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
|
||||
new_node = ConvertMakeSparseToMakeTuple(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetIndices)) {
|
||||
IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
|
||||
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetValues)) {
|
||||
IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
|
||||
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetDenseShape)) {
|
||||
IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
|
||||
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2);
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
#include "frontend/optimizer/irpass/transpose_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/value_based_eliminate.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -157,10 +157,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
mark_interface_fusion_ =
|
||||
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
|
||||
|
||||
// IndexedSlices Eliminate
|
||||
indexed_slices_eliminate_ = MakeSubstitution(
|
||||
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
|
||||
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
|
||||
// RowTensor Eliminate
|
||||
row_tensor_eliminate_ = MakeSubstitution(
|
||||
std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
|
||||
{prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape});
|
||||
|
||||
// SparseTensor Eliminate
|
||||
sparse_tensor_eliminate_ = MakeSubstitution(
|
||||
|
|
|
@ -105,8 +105,8 @@ class OptimizeIRPassLib {
|
|||
// Fusion
|
||||
SubstitutionPtr mark_interface_fusion_;
|
||||
|
||||
// IndexedSlices Eliminate
|
||||
SubstitutionPtr indexed_slices_eliminate_;
|
||||
// RowTensor Eliminate
|
||||
SubstitutionPtr row_tensor_eliminate_;
|
||||
|
||||
// SparseTensor Eliminate
|
||||
SubstitutionPtr sparse_tensor_eliminate_;
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
@ -28,24 +28,24 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
|
||||
// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
|
||||
// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
|
||||
class IndexedSlicesEliminater : public AnfVisitor {
|
||||
// {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}}
|
||||
// {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}}
|
||||
// {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}}
|
||||
class RowTensorEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(1);
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(2);
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(3);
|
||||
|
@ -54,7 +54,7 @@ class IndexedSlicesEliminater : public AnfVisitor {
|
|||
}
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) {
|
||||
tuple_ = cnode;
|
||||
is_match_ = true;
|
||||
}
|
||||
|
@ -72,4 +72,4 @@ class IndexedSlicesEliminater : public AnfVisitor {
|
|||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_
|
|
@ -170,7 +170,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.replace_refkey_by_param_,
|
||||
irpass.make_ref_eliminate_,
|
||||
irpass.get_ref_param_eliminate_,
|
||||
irpass.indexed_slices_eliminate_,
|
||||
irpass.row_tensor_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"b_1", b_1},
|
||||
|
|
|
@ -30,153 +30,165 @@ namespace mindspore {
|
|||
namespace pipeline {
|
||||
|
||||
BuiltInTypeMap &GetMethodMap() {
|
||||
static BuiltInTypeMap method_map = {
|
||||
{kObjectTypeString,
|
||||
{
|
||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
||||
}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
{"__or__", prim::kPrimBoolOr}, // P.bool_or
|
||||
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
|
||||
{"__ne__", std::string("bool_ne")}, // C.bool_ne
|
||||
{"__bool__", prim::kPrimIdentity} // P.identity
|
||||
}},
|
||||
{kNumberTypeInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
|
||||
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
|
||||
}},
|
||||
{kNumberTypeUInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kNumberTypeFloat,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
|
||||
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
|
||||
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("float_bool")}, // C.float_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kObjectTypeTuple,
|
||||
{
|
||||
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
|
||||
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
|
||||
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
|
||||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
{"__len__", prim::kPrimListLen}, // P.list_len,
|
||||
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
|
||||
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
|
||||
{"__ms_next__", std::string("list_next")}, // C.list_next
|
||||
{"append", std::string("list_append")}, // C.list_next
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
static BuiltInTypeMap method_map = {{kObjectTypeString,
|
||||
{
|
||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
||||
}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
{"__or__", prim::kPrimBoolOr}, // P.bool_or
|
||||
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
|
||||
{"__ne__", std::string("bool_ne")}, // C.bool_ne
|
||||
{"__bool__", prim::kPrimIdentity} // P.identity
|
||||
}},
|
||||
{kNumberTypeInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
|
||||
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
|
||||
}},
|
||||
{kNumberTypeUInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kNumberTypeFloat,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
|
||||
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
|
||||
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("float_bool")}, // C.float_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kObjectTypeTuple,
|
||||
{
|
||||
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
|
||||
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
|
||||
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
|
||||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
{"__len__", prim::kPrimListLen}, // P.list_len,
|
||||
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
|
||||
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
|
||||
{"__ms_next__", std::string("list_next")}, // C.list_next
|
||||
{"append", std::string("list_append")}, // C.list_next
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"all", std::string("all_")}, // C.reduce_all
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
return method_map;
|
||||
}
|
||||
|
||||
BuiltInTypeMap &GetAttrMap() {
|
||||
static BuiltInTypeMap attr_map = {
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"all", std::string("all_")}, // C.reduce_all
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
{"shape", std::string("shape_")}, // C.shape_
|
||||
{"dtype", std::string("dtype_")}, // C.dtype_
|
||||
}},
|
||||
{kObjectTypeIndexedSlicesType,
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
{"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values
|
||||
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
|
||||
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
|
||||
{"values", prim::kPrimRowTensorGetValues}, // F.row_tensor_get_values
|
||||
{"indices", prim::kPrimRowTensorGetIndices}, // F.row_tensor_get_indices
|
||||
{"dense_shape", prim::kPrimRowTensorGetDenseShape}, // F.row_tensor_get_dense_shape
|
||||
}},
|
||||
{kObjectTypeSparseTensorType,
|
||||
{
|
||||
|
@ -184,18 +196,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
|
||||
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
return method_map;
|
||||
}
|
||||
|
||||
BuiltInTypeMap &GetAttrMap() {
|
||||
static BuiltInTypeMap attr_map = {{kObjectTypeTensorType,
|
||||
{
|
||||
{"shape", std::string("shape_")}, // C.shape_
|
||||
{"dtype", std::string("dtype_")}, // C.dtype_
|
||||
}}};
|
||||
};
|
||||
return attr_map;
|
||||
}
|
||||
|
||||
|
|
|
@ -132,11 +132,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, {InferImplDebug, true}},
|
||||
// IndexedSlices
|
||||
{prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}},
|
||||
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
|
||||
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
|
||||
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
|
||||
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
|
||||
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
|
||||
// SparseTensor
|
||||
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
|
||||
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
|
||||
|
@ -402,8 +402,8 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
}
|
||||
dic["dtype"] = arg_tensor->BuildType();
|
||||
dic["value"] = BuildValue(arg_tensor->BuildValue());
|
||||
} else if (abs_base->isa<AbstractIndexedSlices>()) {
|
||||
auto arg = dyn_cast<AbstractIndexedSlices>(abs_base);
|
||||
} else if (abs_base->isa<AbstractRowTensor>()) {
|
||||
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
|
||||
dic["shape"] = arg->shape()->shape();
|
||||
dic["dtype"] = arg->BuildType();
|
||||
dic["value"] = BuildValue(arg->BuildValue());
|
||||
|
|
|
@ -348,14 +348,14 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
|
|||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -32,9 +32,9 @@ using mindspore::abstract::AbstractBase;
|
|||
using mindspore::abstract::AbstractClass;
|
||||
using mindspore::abstract::AbstractError;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractIndexedSlices;
|
||||
using mindspore::abstract::AbstractJTagged;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractRowTensor;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSparseTensor;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
|
@ -95,7 +95,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
|
||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
|
||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
|
||||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -136,8 +136,7 @@ REGISTER_PYBIND_DEFINE(
|
|||
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
|
||||
.def(py::init());
|
||||
(void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType").def(py::init());
|
||||
(void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
|
||||
.def(py::init());
|
||||
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
||||
|
|
|
@ -17,10 +17,10 @@ from . import dtype
|
|||
from .api import ms_function
|
||||
from .dtype import *
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
|
||||
from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
|
||||
|
||||
__all__ = [
|
||||
"MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor
|
||||
"MetaTensor", "Tensor", "RowTensor", "SparseTensor", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype"
|
||||
|
|
|
@ -99,7 +99,7 @@ slice_type = typing.Slice
|
|||
ellipsis_type = typing.TypeEllipsis
|
||||
list_type = typing.List
|
||||
tuple_type = typing.Tuple
|
||||
index_slices = typing.IndexedSlicesType()
|
||||
index_slices = typing.RowTensorType()
|
||||
sparse_tensor = typing.SparseTensorType()
|
||||
undetermined = typing.UndeterminedType()
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
|
|||
from . import dtype as mstype
|
||||
from ._register_for_tensor import tensor_operator_registry
|
||||
|
||||
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
|
||||
__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||
np.float32, np.float64, np.bool_)
|
||||
|
@ -267,20 +267,20 @@ class Tensor(Tensor_):
|
|||
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
|
||||
|
||||
|
||||
class IndexedSlices:
|
||||
class RowTensor:
|
||||
"""
|
||||
A sparse representation of a set of tensor slices at given indices.
|
||||
|
||||
An IndexedSlices is typically used to represent a subset of a larger
|
||||
An RowTensor is typically used to represent a subset of a larger
|
||||
tensor dense of shape [L0, D1, .. , DN] where L0 >> D0.
|
||||
|
||||
The values in indices are the indices in the first dimension of the slices
|
||||
that have been extracted from the larger tensor.
|
||||
|
||||
The dense tensor dense represented by an IndexedSlices slices has
|
||||
The dense tensor dense represented by an RowTensor slices has
|
||||
`dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
|
||||
|
||||
IndexedSlices can only be used in the `Cell`'s construct method.
|
||||
RowTensor can only be used in the `Cell`'s contruct method.
|
||||
|
||||
It is not supported in pynative mode at the moment.
|
||||
|
||||
|
@ -291,7 +291,7 @@ class IndexedSlices:
|
|||
of the corresponding dense tensor.
|
||||
|
||||
Returns:
|
||||
IndexedSlices, composed of `indices`, `values`, and `dense_shape`.
|
||||
RowTensor, composed of `indices`, `values`, and `dense_shape`.
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
|
@ -299,8 +299,8 @@ class IndexedSlices:
|
|||
>>> super(Net, self).__init__()
|
||||
>>> self.dense_shape = dense_shape
|
||||
>>> def construct(self, indices, values):
|
||||
>>> x = IndexedSlices(indices, values, self.dense_shape)
|
||||
>>> return x.values(), x.indices(), x.dense_shape()
|
||||
>>> x = RowTensor(indices, values, self.dense_shape)
|
||||
>>> return x.values, x.indices, x.dense_shape
|
||||
>>>
|
||||
>>> indices = Tensor([0])
|
||||
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
|
@ -308,17 +308,20 @@ class IndexedSlices:
|
|||
"""
|
||||
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
"Init IndexedSlices"
|
||||
"Init RowTensor"
|
||||
self.__indices = indices
|
||||
self.__values = values
|
||||
self.__dense_shape = dense_shape
|
||||
|
||||
@property
|
||||
def indices(self):
|
||||
return self.__indices
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return self.__values
|
||||
|
||||
@property
|
||||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
|
||||
|
@ -353,7 +356,7 @@ class SparseTensor:
|
|||
>>> self.dense_shape = dense_shape
|
||||
>>> def construct(self, indices, values):
|
||||
>>> x = SparseTensor(indices, values, self.dense_shape)
|
||||
>>> return x.values(), x.indices(), x.dense_shape()
|
||||
>>> return x.values, x.indices, x.dense_shape
|
||||
>>>
|
||||
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
|
@ -366,11 +369,14 @@ class SparseTensor:
|
|||
self.__values = values
|
||||
self.__dense_shape = dense_shape
|
||||
|
||||
@property
|
||||
def indices(self):
|
||||
return self.__indices
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return self.__values
|
||||
|
||||
@property
|
||||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
|
|
|
@ -1050,16 +1050,16 @@ bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const
|
|||
return AbstractBasePtrListDeepEqual(lhs, rhs);
|
||||
}
|
||||
|
||||
// IndexedSlices
|
||||
TypePtr AbstractIndexedSlices::BuildType() const {
|
||||
// RowTensor
|
||||
TypePtr AbstractRowTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
TypePtr element_type = element()->BuildType();
|
||||
return std::make_shared<IndexedSlicesType>(element_type);
|
||||
return std::make_shared<RowTensorType>(element_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractIndexedSlices::Clone() const {
|
||||
AbstractBasePtr AbstractRowTensor::Clone() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
|
||||
auto clone = std::make_shared<AbstractRowTensor>(element()->Clone());
|
||||
ShapePtr shp = shape();
|
||||
clone->set_shape(shp->Clone());
|
||||
clone->set_value(GetValueTrack());
|
||||
|
@ -1069,9 +1069,9 @@ AbstractBasePtr AbstractIndexedSlices::Clone() const {
|
|||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractIndexedSlices::Broaden() const {
|
||||
AbstractBasePtr AbstractRowTensor::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
|
||||
auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
|
||||
auto shp = shape();
|
||||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
|
@ -1081,9 +1081,9 @@ AbstractBasePtr AbstractIndexedSlices::Broaden() const {
|
|||
return broaden;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
|
||||
AbstractBasePtr AbstractRowTensor::BroadenWithShape() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
|
||||
auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
|
||||
auto shp = shape()->Clone();
|
||||
shp->Broaden();
|
||||
broaden->set_shape(shp);
|
||||
|
@ -1094,7 +1094,7 @@ AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
|
|||
return broaden;
|
||||
}
|
||||
|
||||
std::string AbstractIndexedSlices::ToString() const {
|
||||
std::string AbstractRowTensor::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
BaseShapePtr shape_track = GetShapeTrack();
|
||||
MS_EXCEPTION_IF_NULL(shape_track);
|
||||
|
|
|
@ -593,15 +593,15 @@ struct AbstractBasePtrListEqual {
|
|||
std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
|
||||
bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
|
||||
|
||||
// IndexedSlices
|
||||
class AbstractIndexedSlices : public AbstractUndetermined {
|
||||
// RowTensor
|
||||
class AbstractRowTensor : public AbstractUndetermined {
|
||||
public:
|
||||
explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractUndetermined(element, shape) {}
|
||||
AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
AbstractRowTensor(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
: AbstractUndetermined(element_type, shape) {}
|
||||
~AbstractIndexedSlices() override = default;
|
||||
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
|
||||
~AbstractRowTensor() override = default;
|
||||
MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined)
|
||||
|
||||
const AbstractTensorPtr indices() const { return indices_; }
|
||||
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||
|
|
|
@ -66,7 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
|
|||
ABSTRACT_REPORT_NAME_TRAITS(Type)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(Class)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(RowTensor)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
|
||||
|
||||
|
|
|
@ -179,40 +179,40 @@ bool TensorType::operator==(const Type &other) const {
|
|||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
TypePtr IndexedSlicesType::DeepCopy() const {
|
||||
TypePtr RowTensorType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<IndexedSlicesType>();
|
||||
return std::make_shared<RowTensorType>();
|
||||
}
|
||||
return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
|
||||
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string IndexedSlicesType::ToReprString() const {
|
||||
std::string RowTensorType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "IndexedSlices";
|
||||
return "RowTensor";
|
||||
}
|
||||
return "IndexedSlices[" + element_type_->ToReprString() + "]";
|
||||
return "RowTensor[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string IndexedSlicesType::ToString() const {
|
||||
std::string RowTensorType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "IndexedSlices";
|
||||
return "RowTensor";
|
||||
}
|
||||
return "IndexedSlices[" + element_type_->ToString() + "]";
|
||||
return "RowTensor[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string IndexedSlicesType::DumpText() const {
|
||||
std::string RowTensorType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "IndexedSlices";
|
||||
return "RowTensor";
|
||||
}
|
||||
return "IndexedSlices[" + element_type_->DumpText() + "]";
|
||||
return "RowTensor[" + element_type_->DumpText() + "]";
|
||||
}
|
||||
|
||||
bool IndexedSlicesType::operator==(const Type &other) const {
|
||||
bool RowTensorType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_;
|
||||
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
|
|
|
@ -154,15 +154,15 @@ class TensorType : public Object {
|
|||
};
|
||||
using TensorTypePtr = std::shared_ptr<TensorType>;
|
||||
|
||||
class IndexedSlicesType : public Object {
|
||||
class RowTensorType : public Object {
|
||||
public:
|
||||
IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
|
||||
explicit IndexedSlicesType(const TypePtr &ele)
|
||||
: Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
~IndexedSlicesType() override = default;
|
||||
MS_DECLARE_PARENT(IndexedSlicesType, Object)
|
||||
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
|
||||
explicit RowTensorType(const TypePtr &ele)
|
||||
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
~RowTensorType() override = default;
|
||||
MS_DECLARE_PARENT(RowTensorType, Object)
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
|
||||
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
|
||||
const TypePtr element() const { return element_type_; }
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
|
@ -175,7 +175,7 @@ class IndexedSlicesType : public Object {
|
|||
private:
|
||||
TypePtr element_type_;
|
||||
};
|
||||
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
|
||||
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
|
||||
|
||||
class SparseTensorType : public Object {
|
||||
public:
|
||||
|
|
|
@ -115,8 +115,8 @@ const char *ObjectIdLabel(const TypeId &v) {
|
|||
return "kObjectTypeKeyword";
|
||||
case kObjectTypeTensorType:
|
||||
return "kObjectTypeTensorType";
|
||||
case kObjectTypeIndexedSlicesType:
|
||||
return "kObjectTypeIndexedSlicesType";
|
||||
case kObjectTypeRowTensorType:
|
||||
return "kObjectTypeRowTensorType";
|
||||
case kObjectTypeSparseTensorType:
|
||||
return "kObjectTypeSparseTensorType";
|
||||
case kObjectTypeUndeterminedType:
|
||||
|
|
|
@ -50,7 +50,7 @@ enum TypeId : int {
|
|||
kObjectTypeSlice,
|
||||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeIndexedSlicesType,
|
||||
kObjectTypeRowTensorType,
|
||||
kObjectTypeSparseTensorType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
|
|
|
@ -190,9 +190,9 @@ TypePtr TensorStrToType(const std::string &type_name) {
|
|||
return type;
|
||||
}
|
||||
|
||||
TypePtr IndexedSlicesStrToType(const std::string &type_name) {
|
||||
if (type_name == "IndexedSlices") {
|
||||
return std::make_shared<IndexedSlicesType>();
|
||||
TypePtr RowTensorStrToType(const std::string &type_name) {
|
||||
if (type_name == "RowTensor") {
|
||||
return std::make_shared<RowTensorType>();
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
|
@ -204,7 +204,7 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
|
|||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<IndexedSlicesType>(element_type);
|
||||
return std::make_shared<RowTensorType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr SparseTensorStrToType(const std::string &type_name) {
|
||||
|
@ -364,8 +364,8 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
type = TensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
|
||||
type = UndeterminedStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
|
||||
type = IndexedSlicesStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("RowTensor"), "RowTensor") == 0) {
|
||||
type = RowTensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
|
||||
type = SparseTensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||
|
@ -446,7 +446,7 @@ const TypePtr kTypeExternal = std::make_shared<External>();
|
|||
const TypePtr kTypeEnv = std::make_shared<EnvType>();
|
||||
const TypePtr kTypeType = std::make_shared<TypeType>();
|
||||
const TypePtr kTensorType = std::make_shared<TensorType>();
|
||||
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
|
||||
const TypePtr kRowTensorType = std::make_shared<RowTensorType>();
|
||||
const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
|
||||
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
|
||||
const TypePtr kString = std::make_shared<String>();
|
||||
|
|
|
@ -85,13 +85,13 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
|
|||
|
||||
|
||||
@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
"Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr,
|
||||
gradient, params, moment1, moment2, ps_parameter):
|
||||
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
indices = gradient.indices()
|
||||
values = gradient.values()
|
||||
indices = gradient.indices
|
||||
values = gradient.values
|
||||
if ps_parameter:
|
||||
op_shape = P.Shape()
|
||||
shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
|
||||
|
|
|
@ -24,13 +24,13 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
|
|||
|
||||
|
||||
@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
|
||||
"IndexedSlices", "Tensor", "Tensor", "Bool")
|
||||
"RowTensor", "Tensor", "Tensor", "Bool")
|
||||
def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
|
||||
gradient, weight, moment, ps_parameter):
|
||||
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
indices = gradient.indices()
|
||||
values = gradient.values()
|
||||
indices = gradient.indices
|
||||
values = gradient.values
|
||||
if ps_parameter:
|
||||
op_shape = P.Shape()
|
||||
shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
|
||||
|
|
|
@ -28,13 +28,13 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
|
|||
|
||||
|
||||
@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor",
|
||||
"IndexedSlices", "Tensor", "Tensor", "Tensor")
|
||||
"RowTensor", "Tensor", "Tensor", "Tensor")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
|
||||
moment1, moment2):
|
||||
"""Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient.values(), gradient.indices()))
|
||||
eps, gradient.values, gradient.indices))
|
||||
return success
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.nn.cell import Cell
|
|||
from mindspore.nn.layer.container import CellList
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor, IndexedSlices
|
||||
from mindspore.common.tensor import Tensor, RowTensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
@ -493,14 +493,14 @@ op_gather = P.GatherV2()
|
|||
_apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
||||
|
||||
@_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices")
|
||||
@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
|
||||
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
|
||||
"""Get grad with weight_decay."""
|
||||
if if_apply:
|
||||
indices = gradient.indices()
|
||||
values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values()))
|
||||
shape = gradient.dense_shape()
|
||||
return IndexedSlices(indices, values, shape)
|
||||
indices = gradient.indices
|
||||
values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values))
|
||||
shape = gradient.dense_shape
|
||||
return RowTensor(indices, values, shape)
|
||||
return gradient
|
||||
|
||||
|
||||
|
@ -523,12 +523,12 @@ def tensor_grad_scale(scale, grad):
|
|||
return grad * scale
|
||||
|
||||
|
||||
@_grad_scale.register("Number", "IndexedSlices")
|
||||
@_grad_scale.register("Number", "RowTensor")
|
||||
def tensor_grad_scale_with_sparse(scale, grad):
|
||||
"""Get grad with scale."""
|
||||
if scale == 1.0:
|
||||
return grad
|
||||
return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())
|
||||
return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
|
||||
|
||||
|
||||
class _ConvertToCell(LearningRateSchedule):
|
||||
|
|
|
@ -22,12 +22,12 @@ from .optimizer import Optimizer
|
|||
|
||||
_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
|
||||
|
||||
@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
|
||||
@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor",
|
||||
"Tensor")
|
||||
def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
|
||||
"""Apply sparse proximal_ada_grad optimizer to the weight parameter."""
|
||||
success = True
|
||||
success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
|
||||
success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices))
|
||||
return success
|
||||
|
||||
|
||||
|
|
|
@ -49,6 +49,6 @@ class SparseToDense(Cell):
|
|||
self.sparse_to_dense = P.SparseToDense()
|
||||
|
||||
def construct(self, sparse_tensor):
|
||||
return self.sparse_to_dense(sparse_tensor.indices(),
|
||||
sparse_tensor.values(),
|
||||
sparse_tensor.dense_shape())
|
||||
return self.sparse_to_dense(sparse_tensor.indices,
|
||||
sparse_tensor.values,
|
||||
sparse_tensor.dense_shape)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from mindspore import context
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.communication.management import GlobalComm, get_group_size
|
||||
from mindspore.common.tensor import IndexedSlices
|
||||
from mindspore.common.tensor import RowTensor
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
@ -103,7 +103,7 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter,
|
|||
return grad
|
||||
|
||||
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor")
|
||||
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
|
||||
"""
|
||||
Apply allgather on gradient instead of allreduce for sparse feature.
|
||||
|
@ -118,21 +118,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
|
|||
grad (tuple): The indices, gradient tensor and tensor_shape before operation.
|
||||
|
||||
Returns:
|
||||
IndexedSlices, the gradient after operation.
|
||||
RowTensor, the gradient after operation.
|
||||
"""
|
||||
if allreduce_filter:
|
||||
indices = allgather(grad.indices())
|
||||
dout = allgather(grad.values())
|
||||
indices = allgather(grad.indices)
|
||||
dout = allgather(grad.values)
|
||||
if mean:
|
||||
degree = F.scalar_cast(degree, F.dtype(grad.values()))
|
||||
degree = F.scalar_cast(degree, F.dtype(grad.values))
|
||||
cast_op = P.Cast()
|
||||
mul_op = P.Mul()
|
||||
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
|
||||
grad = IndexedSlices(indices, dout, grad.dense_shape())
|
||||
grad = RowTensor(indices, dout, grad.dense_shape)
|
||||
return grad
|
||||
|
||||
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
|
||||
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
|
||||
"""
|
||||
Apply allgather on gradient instead of allreduce for sparse feature.
|
||||
|
@ -148,20 +148,20 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred
|
|||
ps_parameter (bool): Use parameter server or not.
|
||||
|
||||
Returns:
|
||||
IndexedSlices, the gradient after operation.
|
||||
RowTensor, the gradient after operation.
|
||||
"""
|
||||
if ps_parameter:
|
||||
return grad
|
||||
|
||||
if allreduce_filter:
|
||||
indices = allgather(grad.indices())
|
||||
dout = allgather(grad.values())
|
||||
indices = allgather(grad.indices)
|
||||
dout = allgather(grad.values)
|
||||
if mean:
|
||||
degree = F.scalar_cast(degree, F.dtype(grad.values()))
|
||||
degree = F.scalar_cast(degree, F.dtype(grad.values))
|
||||
cast_op = P.Cast()
|
||||
mul_op = P.Mul()
|
||||
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
|
||||
grad = IndexedSlices(indices, dout, grad.dense_shape())
|
||||
grad = RowTensor(indices, dout, grad.dense_shape)
|
||||
return grad
|
||||
|
||||
|
||||
|
@ -182,18 +182,18 @@ def _tensors_get_datatype(grad):
|
|||
return F.dtype(grad)
|
||||
|
||||
|
||||
@_get_datatype.register("IndexedSlices")
|
||||
@_get_datatype.register("RowTensor")
|
||||
def _tensors_get_datatype_with_sparse(grad):
|
||||
"""
|
||||
Acquire gradient datatype.
|
||||
|
||||
Args:
|
||||
grad (IndexedSlices): The gradient before operation.
|
||||
grad (RowTensor): The gradient before operation.
|
||||
|
||||
Returns:
|
||||
mstype, the datatype of gradient.
|
||||
"""
|
||||
return F.dtype(grad.values())
|
||||
return F.dtype(grad.values)
|
||||
|
||||
|
||||
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
|
||||
|
@ -214,20 +214,20 @@ def _tensors_cast_datatype(datatype, grad):
|
|||
return F.cast(grad, datatype)
|
||||
|
||||
|
||||
@_cast_datatype.register("TypeType", "IndexedSlices")
|
||||
@_cast_datatype.register("TypeType", "RowTensor")
|
||||
def _tensors_cast_datatype_with_sparse(datatype, grad):
|
||||
"""
|
||||
Cast gradient to datatype.
|
||||
|
||||
Args:
|
||||
datatype (mstype): the destination datatype of gradient.
|
||||
grad (IndexedSlices): The gradient before operation.
|
||||
grad (RowTensor): The gradient before operation.
|
||||
|
||||
Returns:
|
||||
IndexedSlices, the gradient after operation.
|
||||
RowTensor, the gradient after operation.
|
||||
"""
|
||||
dout = F.cast(grad.values(), datatype)
|
||||
return IndexedSlices(grad.indices(), dout, grad.dense_shape())
|
||||
dout = F.cast(grad.values, datatype)
|
||||
return RowTensor(grad.indices, dout, grad.dense_shape)
|
||||
|
||||
|
||||
class DistributedGradReducer(Cell):
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
|||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
|
||||
from ..cell import Cell
|
||||
from ...common import Tensor, IndexedSlices
|
||||
from ...common import Tensor, RowTensor
|
||||
from ...common.parameter import Parameter
|
||||
from ...ops import functional as F
|
||||
from ...ops import composite as C
|
||||
|
@ -35,11 +35,11 @@ reciprocal = P.Reciprocal()
|
|||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
@_grad_scale.register("Tensor", "IndexedSlices")
|
||||
def tensor_grad_scale_indexed_slices(scale, grad):
|
||||
return IndexedSlices(grad.indices(),
|
||||
grad.values() * F.cast(reciprocal(scale), F.dtype(grad.values())),
|
||||
grad.dense_shape())
|
||||
@_grad_scale.register("Tensor", "RowTensor")
|
||||
def tensor_grad_scale_row_tensor(scale, grad):
|
||||
return RowTensor(grad.indices,
|
||||
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
|
||||
grad.dense_shape)
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
|
|
@ -27,7 +27,7 @@ from .grad_base import bprop_getters
|
|||
from ..primitive import constexpr
|
||||
from ... import context
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import IndexedSlices
|
||||
from ...common.tensor import RowTensor
|
||||
|
||||
reduce_sum = P.ReduceSum()
|
||||
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||
|
@ -75,12 +75,12 @@ def dout_cast_number(dout, x):
|
|||
dx = cast(dout, get_dtype(x))
|
||||
return dx
|
||||
|
||||
@dout_cast.register("IndexedSlices", "Tensor")
|
||||
def dout_cast_indexed_slices(dout, x):
|
||||
@dout_cast.register("RowTensor", "Tensor")
|
||||
def dout_cast_row_tensor(dout, x):
|
||||
cast = P.Cast()
|
||||
get_dtype = P.DType()
|
||||
values = cast(dout.values(), get_dtype(x))
|
||||
return IndexedSlices(dout.indices(), values, dout.dense_shape())
|
||||
values = cast(dout.values, get_dtype(x))
|
||||
return RowTensor(dout.indices, values, dout.dense_shape)
|
||||
|
||||
|
||||
@bprop_getters.register(P.Cast)
|
||||
|
@ -240,7 +240,7 @@ def get_bprop_embedding_lookup(self):
|
|||
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
||||
# Reshape the 'actual_dout' on device
|
||||
actual_dout = reshape_op(dout, actual_dout_shape_changed)
|
||||
return IndexedSlices(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
|
||||
return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
|
||||
return bprop_sparse
|
||||
|
||||
|
||||
|
@ -369,7 +369,7 @@ def get_bprop_sparse_gather_v2(self):
|
|||
values_shape = indices_size + x_tail_shp
|
||||
values = reshape(dout, values_shape)
|
||||
indices = reshape(indices, indices_size)
|
||||
return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
|
||||
return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
|
||||
if F.rank(dout) == 0:
|
||||
dout = P.ExpandDims()(dout, -1)
|
||||
if F.rank(indices) == 0:
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from .. import operations as P
|
||||
from ...common.tensor import IndexedSlices
|
||||
from ...common.tensor import RowTensor
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
|
||||
_GetTensorSlice, _MirrorOperator, ReduceOp,
|
||||
|
@ -47,9 +47,9 @@ def get_bprop_all_reduce(self):
|
|||
if F.issubclass_(F.typeof(dout), mstype.tensor):
|
||||
dx = all_reduce_grad(dout)
|
||||
else:
|
||||
indices = all_gather(dout.indices())
|
||||
grad = all_gather(dout.values())
|
||||
dx = IndexedSlices(indices, grad, dout.dense_shape())
|
||||
indices = all_gather(dout.indices)
|
||||
grad = all_gather(dout.values)
|
||||
dx = RowTensor(indices, grad, dout.dense_shape)
|
||||
return (dx,)
|
||||
else:
|
||||
|
||||
|
@ -60,12 +60,12 @@ def get_bprop_all_reduce(self):
|
|||
z = cast(z, dtype(dx))
|
||||
dx = mul(dx, z)
|
||||
else:
|
||||
indices = all_gather(dout.indices())
|
||||
grad = all_gather(dout.values())
|
||||
indices = all_gather(dout.indices)
|
||||
grad = all_gather(dout.values)
|
||||
z = equal(x, out)
|
||||
z = cast(z, dtype(grad))
|
||||
grad = mul(grad, z)
|
||||
dx = IndexedSlices(indices, grad, dout.dense_shape())
|
||||
dx = RowTensor(indices, grad, dout.dense_shape)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
@ -195,19 +195,19 @@ def get_bprop_mirror_operator(self):
|
|||
num = F.scalar_cast(dev_num, F.dtype(dx))
|
||||
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
|
||||
else:
|
||||
indices = all_gather(dout.indices())
|
||||
grad = all_gather(dout.values())
|
||||
indices = all_gather(dout.indices)
|
||||
grad = all_gather(dout.values)
|
||||
float_one = F.scalar_cast(1.0, F.dtype(grad))
|
||||
num = F.scalar_cast(dev_num, F.dtype(grad))
|
||||
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
|
||||
dx = IndexedSlices(indices, grad, dout.dense_shape())
|
||||
dx = RowTensor(indices, grad, dout.dense_shape)
|
||||
else:
|
||||
if F.issubclass_(F.typeof(dout), mstype.tensor):
|
||||
dx = all_reduce(dout)
|
||||
else:
|
||||
indices = all_gather(dout.indices())
|
||||
grad = all_gather(dout.values())
|
||||
dx = IndexedSlices(indices, grad, dout.dense_shape())
|
||||
indices = all_gather(dout.indices)
|
||||
grad = all_gather(dout.values)
|
||||
dx = RowTensor(indices, grad, dout.dense_shape)
|
||||
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
|
|
@ -152,10 +152,10 @@ shape_mul = Primitive("shape_mul")
|
|||
# a primitive to compare between tuple.
|
||||
stop_gradient = Primitive("stop_gradient")
|
||||
|
||||
make_indexed_slices = Primitive('MakeIndexedSlices')
|
||||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||
make_row_tensor = Primitive('MakeRowTensor')
|
||||
row_tensor_get_values = Primitive('RowTensorGetValues')
|
||||
row_tensor_get_indices = Primitive('RowTensorGetIndices')
|
||||
row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
|
||||
|
||||
make_sparse_tensor = Primitive('MakeSparseTensor')
|
||||
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
|
||||
|
|
|
@ -389,8 +389,8 @@ class CheckBprop(PrimitiveWithInfer):
|
|||
validator.check_value_type('grads', xshapes, (tuple,), tips)
|
||||
validator.check_value_type('params', yshapes, (tuple,), tips)
|
||||
if len(xshapes) < len(yshapes):
|
||||
raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
|
||||
f" but got {len(xshapes)}.")
|
||||
raise ValueError(f"{tips}, the size of output should be {len(yshapes)},"
|
||||
f" but got {len(xshapes)}.")
|
||||
checking_range = len(yshapes)
|
||||
for i in range(checking_range):
|
||||
xshape = xshapes[i]
|
||||
|
@ -398,8 +398,8 @@ class CheckBprop(PrimitiveWithInfer):
|
|||
if not xshape or not yshape:
|
||||
continue
|
||||
if xshape != yshape:
|
||||
raise TypeError(f"{tips}, the shape of {i}th output should be {yshape},"
|
||||
f" but got {xshape}.")
|
||||
raise ValueError(f"{tips}, the shape of {i}th output should be {yshape},"
|
||||
f" but got {xshape}.")
|
||||
return xshapes
|
||||
|
||||
def infer_dtype(self, xdtypes, ydtypes):
|
||||
|
@ -407,8 +407,8 @@ class CheckBprop(PrimitiveWithInfer):
|
|||
validator.check_value_type('grads', xdtypes, (tuple,), tips)
|
||||
validator.check_value_type('params', ydtypes, (tuple,), tips)
|
||||
if len(xdtypes) < len(ydtypes):
|
||||
raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
|
||||
f" but got {len(xdtypes)}.")
|
||||
raise ValueError(f"{tips}, the size of output should be {len(ydtypes)},"
|
||||
f" but got {len(xdtypes)}.")
|
||||
checking_range = len(ydtypes)
|
||||
for i in range(checking_range):
|
||||
xdtype = xdtypes[i]
|
||||
|
|
|
@ -19,25 +19,16 @@ import pytest
|
|||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter
|
||||
from mindspore import Parameter, ParameterTuple
|
||||
from mindspore import context
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from .....mindspore_test_framework.utils.bprop_util import bprop
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
def teardown_module(module):
|
||||
context.set_context(device_target="Ascend")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class MulAdd(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MulAdd, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return 2 * x + y
|
||||
|
||||
|
@ -45,7 +36,9 @@ class MulAdd(nn.Cell):
|
|||
# In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result
|
||||
return 2 * dout, 2 * y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_mul_add():
|
||||
mul_add = MulAdd()
|
||||
x = Tensor(1, dtype=ms.int32)
|
||||
|
@ -62,7 +55,9 @@ class InlineMulADD(nn.Cell):
|
|||
def construct(self, x, y):
|
||||
return self.mul_add(x, y) + x + self.param * y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_inline_mul_add():
|
||||
inline_mul_add = InlineMulADD()
|
||||
x = Tensor(1, dtype=ms.int32)
|
||||
|
@ -83,7 +78,9 @@ class WithParameter(nn.Cell):
|
|||
# In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result
|
||||
return self.param1 * self.param2 * dout, 2 * y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_param():
|
||||
with_param = WithParameter()
|
||||
with pytest.raises(RuntimeError):
|
||||
|
@ -91,20 +88,21 @@ def test_with_param():
|
|||
|
||||
|
||||
class WithNoBprop(nn.Cell):
|
||||
def __init__(self):
|
||||
super(WithNoBprop, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return 2 * x + y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_no_bprop():
|
||||
with_no_bprop = WithNoBprop()
|
||||
x = Tensor(1, dtype=ms.int32)
|
||||
y = Tensor(2, dtype=ms.int32)
|
||||
C.grad_all(with_no_bprop)(x, y)
|
||||
|
||||
assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_in_bprop_1():
|
||||
class GradInBprop_1(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -140,7 +138,9 @@ def test_grad_in_bprop_1():
|
|||
assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
|
||||
assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_in_bprop_2():
|
||||
class GradInBprop_1(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -179,7 +179,9 @@ def test_grad_in_bprop_2():
|
|||
assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
|
||||
assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_in_bprop_3():
|
||||
class GradInBprop_1(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -230,7 +232,9 @@ class OneInputBprop(nn.Cell):
|
|||
def bprop(self, x, out, dout):
|
||||
return (5 * x,)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_one_input_bprop():
|
||||
net = OneInputBprop()
|
||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
|
@ -239,9 +243,6 @@ def test_grad_one_input_bprop():
|
|||
|
||||
|
||||
class TwoInput(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return x * y
|
||||
|
||||
|
@ -258,12 +259,17 @@ class InlineBpropTwoInput(nn.Cell):
|
|||
grads = C.grad_all(self.f)(x, y)
|
||||
return grads[0] * 2, grads[1] * 2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_inline_bprop_two_input():
|
||||
net = InlineBpropTwoInput()
|
||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
input2 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
C.grad_all(net)(input1, input2)
|
||||
grads = C.grad_all(net)(input1, input2)
|
||||
assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||
assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||
assert len(grads) == 2
|
||||
|
||||
|
||||
class TwoInputBprop(nn.Cell):
|
||||
|
@ -314,7 +320,9 @@ class InlineMutilTwoInputParameterCell(nn.Cell):
|
|||
output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_inline_bprop_multi_input():
|
||||
net = InlineMutilTwoInputParameterCell()
|
||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
|
@ -335,29 +343,54 @@ class MulAddWithParam(nn.Cell):
|
|||
def construct(self, x):
|
||||
return self.mul_add(self.param, x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_refkey_bprop():
|
||||
net = MulAddWithParam()
|
||||
grad_by_list = C.GradOperation('get_by_list', get_all=True, get_by_list=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
|
||||
def construct(self, x):
|
||||
weights = self.weights
|
||||
grads = grad_by_list(self.network, weights)(x)
|
||||
return grads
|
||||
network = GradWrap(MulAddWithParam())
|
||||
input_data = Tensor(np.array([2, 2], np.float32))
|
||||
grads = bprop(net, input_data,
|
||||
grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))),
|
||||
wrt=['params', 'inputs'],
|
||||
params=net.trainable_params())
|
||||
grads = network(input_data)
|
||||
assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
|
||||
assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||
|
||||
|
||||
class MulAddWithWrongOutputType(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MulAddWithWrongOutputType, self).__init__()
|
||||
class MulAddWithWrongOutputNum(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2 * x + y
|
||||
|
||||
def bprop(self, x, y, out, dout):
|
||||
return (2 * dout,)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_mul_add_with_wrong_output_num():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputNum()
|
||||
with pytest.raises(TypeError):
|
||||
C.grad_all(mul_add)(1, 2)
|
||||
|
||||
|
||||
class MulAddWithWrongOutputType(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2 * x + y
|
||||
|
||||
def bprop(self, x, y, out, dout):
|
||||
return 2 * dout, 2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_mul_add_with_wrong_output_type():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputType()
|
||||
|
@ -376,7 +409,9 @@ class MulAddWithWrongOutputShape(nn.Cell):
|
|||
def bprop(self, x, y, out, dout):
|
||||
return 2, self.ones
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_mul_add_with_wrong_output_shape():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputShape()
|
|
@ -606,14 +606,14 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
|
|||
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_indexed_slices) {
|
||||
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices");
|
||||
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices");
|
||||
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values");
|
||||
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values");
|
||||
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape");
|
||||
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape");
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_});
|
||||
TEST_F(TestOptLib, test_row_tensor) {
|
||||
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "before_get_indices");
|
||||
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "after_get_indices");
|
||||
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_row_tensor", "before_get_values");
|
||||
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_row_tensor", "after_get_values");
|
||||
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "before_get_dense_shape");
|
||||
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "after_get_dense_shape");
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.row_tensor_eliminate_});
|
||||
ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
||||
|
|
|
@ -1130,17 +1130,17 @@ def test_adjust_allreduce_mul_add(tag):
|
|||
return fns[tag]
|
||||
|
||||
|
||||
def test_indexed_slices(tag):
|
||||
def test_row_tensor(tag):
|
||||
""" test_add_zero """
|
||||
fns = FnDict()
|
||||
make_indexed_slices = Primitive('MakeIndexedSlices')
|
||||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||
make_row_tensor = Primitive('MakeRowTensor')
|
||||
row_tensor_get_values = Primitive('RowTensorGetValues')
|
||||
row_tensor_get_indices = Primitive('RowTensorGetIndices')
|
||||
row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
|
||||
|
||||
@fns
|
||||
def before_get_indices(x, y, z):
|
||||
return indexed_slices_get_indices(make_indexed_slices(x, y, z))
|
||||
return row_tensor_get_indices(make_row_tensor(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_indices(x, y, z):
|
||||
|
@ -1148,7 +1148,7 @@ def test_indexed_slices(tag):
|
|||
|
||||
@fns
|
||||
def before_get_values(x, y, z):
|
||||
return indexed_slices_get_values(make_indexed_slices(x, y, z))
|
||||
return row_tensor_get_values(make_row_tensor(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_values(x, y, z):
|
||||
|
@ -1156,7 +1156,7 @@ def test_indexed_slices(tag):
|
|||
|
||||
@fns
|
||||
def before_get_dense_shape(x, y, z):
|
||||
return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z))
|
||||
return row_tensor_get_dense_shape(make_row_tensor(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_dense_shape(x, y, z):
|
||||
|
|
|
@ -13,10 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
@File : test_indexed_slices.py
|
||||
@File : test_row_tensor.py
|
||||
@Author:
|
||||
@Date : 2020-06-08
|
||||
@Desc : test mindspore indexed_slices's operation
|
||||
@Desc : test mindspore row_tensor's operation
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -29,7 +29,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops._grad.grad_base import bprop_getters
|
||||
from mindspore import Tensor, IndexedSlices, context
|
||||
from mindspore import Tensor, RowTensor, context
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
@ -122,7 +122,7 @@ def get_bprop_sparse_gather_v2(self):
|
|||
values_shape = indices_size + x_tail_shp
|
||||
values = reshape(dout, values_shape)
|
||||
indices = reshape(indices, indices_size)
|
||||
return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
|
||||
return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
|
||||
if F.rank(dout) == 0:
|
||||
dout = P.ExpandDims()(dout, -1)
|
||||
if F.rank(indices) == 0:
|
||||
|
@ -142,10 +142,10 @@ def get_bprop_sparse_gather_v2(self):
|
|||
|
||||
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
|
||||
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
|
||||
def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
|
||||
m, v, gradient, decay_flag):
|
||||
return gradient.values()
|
||||
"Tensor", "Tensor", "Tensor", "RowTensor", "Bool")
|
||||
def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
|
||||
m, v, gradient, decay_flag):
|
||||
return gradient.values
|
||||
|
||||
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
|
@ -219,35 +219,35 @@ class AdamWeightDecaySparse(Optimizer):
|
|||
return updated_velocity
|
||||
|
||||
|
||||
def test_indexed_slices_make_indexed_slices():
|
||||
class MakeIndexedSlices(nn.Cell):
|
||||
def test_row_tensor_make_row_tensor():
|
||||
class MakeRowTensor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MakeIndexedSlices, self).__init__()
|
||||
super(MakeRowTensor, self).__init__()
|
||||
self.dense_shape = (3, 2)
|
||||
def construct(self, indices, values):
|
||||
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
||||
ret = (RowTensor(indices, values, self.dense_shape),)
|
||||
return ret[0]
|
||||
indices = Tensor([1, 2])
|
||||
values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
|
||||
MakeIndexedSlices()(indices, values)
|
||||
MakeRowTensor()(indices, values)
|
||||
|
||||
|
||||
class IndexedSlicesGetAttr(nn.Cell):
|
||||
class RowTensorGetAttr(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super(IndexedSlicesGetAttr, self).__init__()
|
||||
super(RowTensorGetAttr, self).__init__()
|
||||
self.dense_shape = dense_shape
|
||||
def construct(self, indices, values):
|
||||
x = IndexedSlices(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
x = RowTensor(indices, values, self.dense_shape)
|
||||
return x.values, x.indices, x.dense_shape
|
||||
|
||||
|
||||
def test_indexed_slices_attr():
|
||||
def test_row_tensor_attr():
|
||||
indices = Tensor([0])
|
||||
values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
IndexedSlicesGetAttr((3, 2))(indices, values)
|
||||
RowTensorGetAttr((3, 2))(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_sparse_gatherv2_grad_all():
|
||||
def test_row_tensor_sparse_gatherv2_grad_all():
|
||||
grad_all = C.GradOperation('get_all', get_all=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
@ -255,7 +255,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
|
|||
self.network = network
|
||||
def construct(self, x, y):
|
||||
grad = grad_all(self.network)(x, y)
|
||||
return grad[0].indices(), grad[0].values(), grad[0].dense_shape()
|
||||
return grad[0].indices, grad[0].values, grad[0].dense_shape
|
||||
class SparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseGatherV2, self).__init__()
|
||||
|
@ -268,7 +268,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
|
|||
GradWrap(SparseGatherV2())(params, indices)
|
||||
|
||||
|
||||
def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
||||
def test_row_tensor_sparse_gatherv2_grad_with_pram():
|
||||
grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
@ -279,7 +279,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
|||
weights = self.weights
|
||||
grad = grad_by_list(self.network, weights)(x)
|
||||
x = grad[0]
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
return x.values, x.indices, x.dense_shape
|
||||
class SparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseGatherV2, self).__init__()
|
||||
|
@ -293,7 +293,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
|||
network(indices)
|
||||
|
||||
|
||||
def test_indexed_slices_env_get():
|
||||
def test_row_tensor_env_get():
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Loss, self).__init__()
|
||||
|
@ -321,7 +321,7 @@ def test_indexed_slices_env_get():
|
|||
train_network(inputs, label)
|
||||
|
||||
|
||||
def test_indexed_slices_model_train():
|
||||
def test_row_tensor_model_train():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, in_features, out_features):
|
||||
super(Net, self).__init__()
|
||||
|
@ -347,76 +347,76 @@ def test_indexed_slices_model_train():
|
|||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_indexed_slices_values_dim_greater_than_dense_shape_dim():
|
||||
def test_row_tensor_values_dim_greater_than_dense_shape_dim():
|
||||
indices = Tensor(np.array([0, 1], dtype=np.int32))
|
||||
values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
|
||||
dense_shape = (3, 4)
|
||||
with pytest.raises(TypeError):
|
||||
IndexedSlicesGetAttr(dense_shape)(indices, values)
|
||||
RowTensorGetAttr(dense_shape)(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_values_dim_less_than_dense_shape_dim():
|
||||
def test_row_tensor_values_dim_less_than_dense_shape_dim():
|
||||
indices = Tensor(np.array([0, 1], dtype=np.int32))
|
||||
values = Tensor(np.random.randn(2, 4).astype(np.float32))
|
||||
dense_shape = (3, 4, 5)
|
||||
with pytest.raises(TypeError):
|
||||
IndexedSlicesGetAttr(dense_shape)(indices, values)
|
||||
RowTensorGetAttr(dense_shape)(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_value_and_dense_shape_illegal():
|
||||
def test_row_tensor_value_and_dense_shape_illegal():
|
||||
indices = Tensor(np.array([0, 1], dtype=np.int32))
|
||||
values = Tensor(np.random.randn(2, 4).astype(np.float32))
|
||||
dense_shape = (3, 5)
|
||||
with pytest.raises(TypeError):
|
||||
IndexedSlicesGetAttr(dense_shape)(indices, values)
|
||||
RowTensorGetAttr(dense_shape)(indices, values)
|
||||
|
||||
|
||||
class IndexedSlicesValuesDouble(nn.Cell):
|
||||
class RowTensorValuesDouble(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices()
|
||||
values = x.values() * 2
|
||||
dense_shape = x.dense_shape()
|
||||
return IndexedSlices(indices, values, dense_shape)
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
dense_shape = x.dense_shape
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
|
||||
class IndexedSlicesValuesAdd2(nn.Cell):
|
||||
class RowTensorValuesAdd2(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices()
|
||||
values = x.values() + 2
|
||||
dense_shape = x.dense_shape()
|
||||
return IndexedSlices(indices, values, dense_shape)
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
dense_shape = x.dense_shape
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
|
||||
class IndexedSlicesWithControlIf(nn.Cell):
|
||||
class RowTensorWithControlIf(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super().__init__()
|
||||
self.op1 = IndexedSlicesValuesDouble()
|
||||
self.op2 = IndexedSlicesValuesAdd2()
|
||||
self.op1 = RowTensorValuesDouble()
|
||||
self.op2 = RowTensorValuesAdd2()
|
||||
self.dense_shape = dense_shape
|
||||
|
||||
def construct(self, a, b, indices, values):
|
||||
x = IndexedSlices(indices, values, self.dense_shape)
|
||||
x = RowTensor(indices, values, self.dense_shape)
|
||||
if a > b:
|
||||
x = self.op1(x)
|
||||
else:
|
||||
x = self.op2(x)
|
||||
return x.indices(), x.values()
|
||||
return x.indices, x.values
|
||||
|
||||
|
||||
def test_indexed_slices_with_control_flow_if():
|
||||
def test_row_tensor_with_control_flow_if():
|
||||
a = Tensor(np.array(0).astype(np.int32))
|
||||
b = Tensor(np.array(2).astype(np.int32))
|
||||
indices = Tensor(np.array([0, 2]).astype(np.int32))
|
||||
values = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
dense_shape = (5, 2)
|
||||
|
||||
net = IndexedSlicesWithControlIf(dense_shape)
|
||||
net = RowTensorWithControlIf(dense_shape)
|
||||
net(a, b, indices, values)
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ def test_sparse_tensor_attr():
|
|||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
x = SparseTensor(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
return x.values, x.indices, x.dense_shape
|
||||
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
|
|
|
@ -175,7 +175,7 @@ def test_bprop_with_wrong_output_num():
|
|||
def construct(self, x, y):
|
||||
return BpropWithWrongOutputNum()(x, y)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
|
||||
|
||||
def test_bprop_with_wrong_output_type():
|
||||
|
@ -247,7 +247,7 @@ def test_bprop_with_wrong_output_shape():
|
|||
def construct(self, x):
|
||||
return BpropWithWrongOutputShape()(x)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
net = BpropWithWrongOutputShapeCell()
|
||||
net.set_grad()
|
||||
C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
"""
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor, IndexedSlices, SparseTensor
|
||||
from mindspore import context, Tensor, RowTensor, SparseTensor
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)
|
||||
|
@ -36,18 +36,18 @@ class GradWrap(nn.Cell):
|
|||
return grad
|
||||
|
||||
|
||||
def test_indexed_slices_attr():
|
||||
class IndexedSlicesGetAttr(nn.Cell):
|
||||
def test_row_tensor_attr():
|
||||
class RowTensorGetAttr(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super(IndexedSlicesGetAttr, self).__init__()
|
||||
super(RowTensorGetAttr, self).__init__()
|
||||
self.dense_shape = dense_shape
|
||||
def construct(self, indices, values):
|
||||
x = IndexedSlices(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
x = RowTensor(indices, values, self.dense_shape)
|
||||
return x.values, x.indices, x.dense_shape
|
||||
indices = Tensor([0])
|
||||
values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
IndexedSlicesGetAttr((3, 2))(indices, values)
|
||||
GradWrap(IndexedSlicesGetAttr((3, 2)))(indices, values)
|
||||
RowTensorGetAttr((3, 2))(indices, values)
|
||||
GradWrap(RowTensorGetAttr((3, 2)))(indices, values)
|
||||
|
||||
|
||||
def test_sparse_tensor_attr():
|
||||
|
@ -57,7 +57,7 @@ def test_sparse_tensor_attr():
|
|||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
x = SparseTensor(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
return x.values, x.indices, x.dense_shape
|
||||
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
|
|
Loading…
Reference in New Issue