Apply indexed_slices

panyifeng 2020-07-02 11:08:54 +08:00
parent e03bd975a9
commit 44e74ad5aa
35 changed files with 198 additions and 356 deletions

View File

@ -45,7 +45,8 @@ FuncGraph::FuncGraph()
hyper_param_count_(0),
is_generated_(false),
return_(nullptr),
manager_(std::weak_ptr<FuncGraphManager>()) {
manager_(std::weak_ptr<FuncGraphManager>()),
stub_(false) {
debug_info_ = std::make_shared<GraphDebugInfo>();
}

View File

@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase {
void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
bool HasEffect(const CNodePtr &cnode);
bool stub() const { return stub_; }
void set_stub(bool stub) { stub_ = stub; }
private:
// graph is manipulated by manager and others
friend FuncGraphManager;
@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase {
// CNode order which relates to origin code order
std::list<CNodePtr> order_;
bool stub_;
};
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {

View File

@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
(*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
(*target_func_graph)->set_is_generate(func_graph->is_generated());
(*target_func_graph)->set_stub(func_graph->stub());
TraceManager::EndTrace();
}
@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
new_func_graph->set_is_generate(func_graph->is_generated());
new_func_graph->set_stub(func_graph->stub());
for (auto &item : func_graph->parameter_default_value()) {
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
}

View File

@ -30,6 +30,7 @@
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/context/ms_context.h"
#include "utils/symbolic.h"
#include "pybind_api/api_register.h"
#include "./common.h"
@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
}
return item.second;
}
// Try best match
py::function py_fn_subclass;
size_t subclass_match_cnt = 0;
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
return py::none();
}
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();
for (size_t i = 0; i < types.size(); ++i) {
auto param = stub->add_parameter();
parameters.push_back(param);
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
undetermined_param = param;
}
auto match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
!IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
match = false;
break;
}
if (undetermined_param != nullptr) {
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < types.size(); ++i) {
if (types[i]->type_id() == kObjectTypeFunction) {
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
inputs.push_back(stub->NewCNode(call_prim));
} else {
inputs.push_back(parameters[i]);
}
}
if (!match) {
continue;
}
py_fn_subclass = item.second;
subclass_match_cnt++;
auto stub_output = stub->NewCNode(inputs);
stub->set_output(stub_output);
stub->set_stub(true);
return stub;
}
if (subclass_match_cnt > 1) {
MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
}
if (subclass_match_cnt == 1) {
MS_LOG(DEBUG) << "Found one subclass match";
return py_fn_subclass;
}
return py::none();
return nullptr;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
return func_graph;
}
auto stub = GenerateStubFunc(types);
if (stub != nullptr) {
MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
return stub;
}
std::ostringstream oss;
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
<< "`, corresponding location info:\n";

View File

@ -23,8 +23,8 @@
#include "pipeline/static_analysis/param_validator.h"
#include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/utils.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "utils/symbolic.h"
namespace mindspore {
namespace abstract {
@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
return AbstractFunction::MakeAbstractFunction(jv);
}
class UndeterminedShapeType {
public:
explicit UndeterminedShapeType(const std::string &env_str) {
// param_name indices_shape indices_type values_shape values_type dense_shape
// export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1
// 2:Float32:3 1 2"
std::vector<string> fields;
string tmp;
std::stringstream input(env_str);
while (std::getline(input, tmp, ':')) {
fields.push_back(tmp);
}
if (fields.size() != fields_num) {
MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
}
param_name_ = fields[0];
indices_shape_ = GetShape(fields[1]);
indices_type_ = StringToType(fields[2]);
values_shape_ = GetShape(fields[3]);
values_type_ = StringToType(fields[4]);
auto dense_shape_vec = GetShape(fields[5]);
AbstractBasePtrList dense_shape_list;
(void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
[](const auto &elem) { return FromValue(elem, false); });
dense_shape_ = dense_shape_list;
}
~UndeterminedShapeType() = default;
const std::string &param_name() { return param_name_; }
const std::vector<int> &indices_shape() { return indices_shape_; }
const TypePtr &indices_type() { return indices_type_; }
const std::vector<int> &values_shape() { return values_shape_; }
const TypePtr &values_type() { return values_type_; }
const AbstractBasePtrList &dense_shape() { return dense_shape_; }
private:
std::string param_name_;
std::vector<int> indices_shape_;
TypePtr indices_type_;
std::vector<int> values_shape_;
TypePtr values_type_;
AbstractBasePtrList dense_shape_;
static const size_t fields_num;
std::vector<int> GetShape(const std::string &shape_str);
};
std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
std::vector<int> ret;
std::istringstream iss(shape_str);
int elem;
while (iss.good()) {
iss >> elem;
ret.emplace_back(elem);
}
return ret;
}
const size_t UndeterminedShapeType::fields_num = 6;
std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
std::string tmp;
std::stringstream input(sparse_shape_types);
g_undetermined_configs.clear();
while (std::getline(input, tmp, ';')) {
auto config = UndeterminedShapeType(tmp);
g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
MS_LOG(DEBUG) << "Undetermined config from env: " << tmp;
}
}
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
}
if (!key->sparse_grad().empty()) {
// Will be fixed once undetermined type ready
if (g_undetermined_configs.empty()) {
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types;
if (sparse_shape_types.empty()) {
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
}
InitUndeterminedFromEnv(sparse_shape_types);
}
auto shape_types = g_undetermined_configs.find(key->sparse_grad());
if (shape_types == g_undetermined_configs.end()) {
MS_LOG(EXCEPTION) << "Param " << key->ToString()
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES";
}
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
AbstractBasePtrList sparse_list;
// indices
auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type());
auto indices =
std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types->second.indices_shape()));
sparse_list.emplace_back(indices);
// values
auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type());
auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape()));
sparse_list.emplace_back(dout);
// dense_shape
sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape()));
return std::make_shared<AbstractTuple>(sparse_list);
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
bool enable_sparse = context->enable_sparse();
if (enable_sparse && dflt->isa<AbstractTensor>()) {
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
}
if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
return dflt;
}
@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
if (type->type_id() != kObjectTypeRefKey) {
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
}
auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
return ret;
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
}
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,

View File

@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
}
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor {
// G
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
// Do not inline GraphKernel to Cell.

View File

@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
std::string env = common::GetEnv("SLICE_ENV");
if (!env.empty()) {
MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
abstract::InitUndeterminedFromEnv(env);
}
}

View File

@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
ValuePtr value = param_value->value();
constexpr bool broaden = true;
AbstractBasePtr ptr = abstract::FromValue(value, broaden);
ptr->set_sparse_grad(param_value->sparse_grad());
ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad());
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);

View File

@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
.def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
.def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
.def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
.def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

View File

@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
return true;
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup},
{"add_control_depend", AddControlDependPass},
{"opt_control", ControlGroup},
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kGePasses = {
{"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass},
{"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
} // namespace pipeline

View File

@ -146,37 +146,35 @@ MethodMap &GetMethodMap() {
}},
{kObjectTypeTensorType,
{
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
{"__truediv__", std::string("truediv")}, // C.truediv
{"__floordiv__", std::string("floordiv")}, // C.floordiv
{"__mod__", std::string("mod")}, // C.mod
{"__pow__", std::string("pow_")}, // C.pow
{"__floor__", std::string("array_floor")}, // C.array_floor
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
{"__pos__", std::string("array_uadd")}, // C.array_uadd
{"__neg__", std::string("array_usub")}, // C.array_usub
{"__eq__", std::string("eq")}, // C.eq
{"__ne__", std::string("ne")}, // C.ne
{"__lt__", std::string("lt")}, // C.lt
{"__gt__", std::string("gt")}, // C.gt
{"__le__", std::string("le")}, // C.le
{"__ge__", std::string("ge")}, // C.ge
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
{"transpose", std::string("transpose")}, // P.transpose
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
{"__truediv__", std::string("truediv")}, // C.truediv
{"__floordiv__", std::string("floordiv")}, // C.floordiv
{"__mod__", std::string("mod")}, // C.mod
{"__pow__", std::string("pow_")}, // C.pow
{"__floor__", std::string("array_floor")}, // C.array_floor
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
{"__pos__", std::string("array_uadd")}, // C.array_uadd
{"__neg__", std::string("array_usub")}, // C.array_usub
{"__eq__", std::string("eq")}, // C.eq
{"__ne__", std::string("ne")}, // C.ne
{"__lt__", std::string("lt")}, // C.lt
{"__gt__", std::string("gt")}, // C.gt
{"__le__", std::string("le")}, // C.le
{"__ge__", std::string("ge")}, // C.ge
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
{"transpose", std::string("transpose")}, // P.transpose
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
}},
{kObjectTypeIndexedSlicesType,
{
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
{"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape

View File

@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const {
AbstractBasePtr AbstractBase::Broaden() const {
AbstractBasePtr clone = Clone();
clone->set_value(kAnyValue);
clone->set_sparse_grad(sparse_grad_);
return clone;
}
@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const {
MS_EXCEPTION_IF_NULL(type_);
MS_EXCEPTION_IF_NULL(shape_);
buffer << type_name() << "("
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
<< " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
return buffer.str();
}
@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (*this == *other) {
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return shared_from_base<AbstractBase>();
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
if (res_value == value_self) {
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return shared_from_base<AbstractBase>();
}
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return std::make_shared<AbstractScalar>(res_value, res_type);
}
AbstractBasePtr AbstractType::Clone() const {
@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
if (sparse_grad() == other->sparse_grad()) {
return shared_from_base<AbstractBase>();
}
return shared_from_base<AbstractBase>();
}
auto element = element_->Join(other_tensor->element_);
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractTensor>(element, shape);
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return std::make_shared<AbstractTensor>(element, shape);
}
bool AbstractTensor::operator==(const AbstractTensor &other) const {
@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const {
ShapePtr shp = shape();
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_sparse_grad(sparse_grad());
clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
return clone;
}
@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const {
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
return broaden;
}
@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
return broaden;
}
@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const {
MS_EXCEPTION_IF_NULL(value_track);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
<< " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
return buffer.str();
}

View File

@ -44,7 +44,7 @@ class AbstractBase : public Base {
public:
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
const BaseShapePtr &shape = kNoShape)
: value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
: value_(value), type_(type), shape_(shape) {}
~AbstractBase() override = default;
MS_DECLARE_PARENT(AbstractBase, Base)
@ -53,17 +53,11 @@ class AbstractBase : public Base {
virtual bool operator==(const AbstractBase &other) const;
void set_value(const ValuePtr &value) { value_ = value; }
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
has_indexed_slices_grad_ = has_indexed_slices_grad;
}
void set_type(const TypePtr &type) { type_ = type; }
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
const std::string &value_desc() const { return value_desc_; }
ValuePtr GetValueTrack() const { return value_; }
const std::string &sparse_grad() const { return sparse_grad_; }
const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
TypePtr GetTypeTrack() const { return type_; }
BaseShapePtr GetShapeTrack() const { return shape_; }
@ -91,8 +85,6 @@ class AbstractBase : public Base {
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
std::string sparse_grad_;
bool has_indexed_slices_grad_;
};
class AbstractScalar : public AbstractBase {

View File

@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
}
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
<< ", is stub: " << fg->stub();
if (fg->stub()) {
return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
}
return std::make_shared<EvalResult>(ret_base, nullptr);
}

View File

@ -25,6 +25,7 @@
#include <vector>
#include "pipeline/static_analysis/static_analysis.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace abstract {
@ -59,6 +60,13 @@ class Evaluator : public Base {
}
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
return true;

View File

@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
using mindspore::parse::PyObjectWrapper;
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
auto ret_abstract = AbstractEval(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
return ret_abstract;
}
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
} // end anonymous namespace
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
return ret_abstract;
}
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
return ret_abstract;
}
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
}
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
return ret_abstract;
}
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
return ret_abstract;
}
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
if (nargs_ != args.size()) {
@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
MS_EXCEPTION_IF_NULL(ref_value);
ret->set_sparse_grad(ref_value->sparse_grad());
ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
}
@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
x = SensitivityTransform(x);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
abs_scalar->set_sparse_grad(x->sparse_grad());
abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
}
};
@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag) {
auto ret_abstract = AbstractEval(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
return ret_abstract;
}
auto ret_abstract = AbstractEval(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
return ret_abstract;
}
// Inputs: data, item
if (args_spec_list.size() != 2) {

View File

@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);

View File

@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
AbstractFunctionPtr func = real_a->GetUnique();
SpecializeStatusCode errcode;
ScopeGuard scope_guard(node->scope());
AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
if (repl == nullptr) {
if (errcode == kSpecializeFindUniqueArgvalDead) {
const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
return repl;
}
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractFunctionPtr &func,
const AbstractBasePtrList &args,
SpecializeStatusCode *errcode) {
MS_EXCEPTION_IF_NULL(abs);
@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
<< ", graph: " << context->func_graph()->get_return()->DebugString();
if (context->func_graph()->stub()) {
MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
<< ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
<< ", " << node->ToString();
return node;
}
FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
v->set_flag(kFuncGraphFlagUndetermined, false);
return BuildValueNode(v, abs);
}
@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
return kSpecializeSuccess;
} else if (choices->empty()) {
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
<< func->type_name();
return kSpecializeFindUniqueArgvalDead;
} else {
if (IsPolyFunc(func, argvals)) {

View File

@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
// Build a specialized node from given argvals;
AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractBasePtrList &argvals);
AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
const AbstractBasePtrList &args, SpecializeStatusCode *errcode);
AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
SpecializeStatusCode *errcode);
// Find the unique argument values which can be used to specialize a primitive or graph function.
SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,

View File

@ -89,7 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
max_device_memory_ = kDefaultMaxDeviceMemory;
print_file_path_ = "";
enable_graph_kernel_ = false;
enable_sparse_flag_ = false;
enable_sparse_ = false;
}
std::shared_ptr<MsContext> MsContext::GetInstance() {

View File

@ -161,8 +161,8 @@ class MsContext {
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
bool enable_graph_kernel() const { return enable_graph_kernel_; }
bool enable_sparse_flag() const { return enable_sparse_flag_; }
void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
bool enable_sparse() const { return enable_sparse_; }
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
private:
MsContext(const std::string &backend_policy, const std::string &target);
@ -207,7 +207,7 @@ class MsContext {
float max_device_memory_;
std::string print_file_path_;
bool enable_graph_kernel_;
bool enable_sparse_flag_;
bool enable_sparse_;
};
} // namespace mindspore

View File

@ -51,18 +51,13 @@ class Parameter:
requires_grad (bool): True if the parameter requires gradient. Default: True.
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
broadcast and gradients communication would not be applied on parameters. Default: False.
sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false.
"""
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
sparse_grad="", has_indexed_slices_grad=False):
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
self._value = ParamValue()
self.set_parameter_data(default_input)
self.name = name
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.sparse_grad = sparse_grad
self.has_indexed_slices_grad = has_indexed_slices_grad
self._is_init = False
self._sliced = False
if context.get_context("mode") == context.PYNATIVE_MODE:
@ -177,28 +172,6 @@ class Parameter:
raise TypeError("`requires_grad` parameter must be bool type")
self._value.requires_grad = value
@property
def sparse_grad(self):
"""Return whether the parameter's gradient is sparse."""
return self._value.sparse_grad
@sparse_grad.setter
def sparse_grad(self, value=""):
if not isinstance(value, str):
raise TypeError("`sparse_grad` parameter must be str type")
self._value.sparse_grad = value
@property
def has_indexed_slices_grad(self):
"""Return whether the parameter's gradient is indexed_slices."""
return self._value.has_indexed_slices_grad
@has_indexed_slices_grad.setter
def has_indexed_slices_grad(self, value=False):
if not isinstance(value, bool):
raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
self._value.has_indexed_slices_grad = value
@property
def data(self):
return self.default_input
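After this change a sparse parameter is declared like any other; a one-line sketch mirroring the updated tests (tensor shape and name are illustrative):

import numpy as np
from mindspore import Tensor, Parameter

# sparse_grad / has_indexed_slices_grad keyword arguments are gone; sparsity is now
# driven by context.set_context(enable_sparse=True) and the operators in the network.
w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")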

View File

@ -367,14 +367,6 @@ class _Context:
def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag)
@property
def enable_sparse(self):
return self._context_handle.get_enable_sparse_flag()
@enable_sparse.setter
def enable_sparse(self, enable_sparse_flag):
self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
@property
def max_device_memory(self):
return self._context_handle.get_max_device_memory()
@ -408,6 +400,13 @@ class _Context:
full_file_name = print_file_path
self._context_handle.set_print_file_path(full_file_name)
@property
def enable_sparse(self):
return self._context_handle.get_enable_sparse()
@enable_sparse.setter
def enable_sparse(self, enable_sparse):
self._context_handle.set_enable_sparse(enable_sparse)
def check_input_format(x):
import re
@ -601,7 +600,7 @@ def set_context(**kwargs):
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
suffix to the file.
enable_sparse (bool): Whether to enable sparse feature. Default: False.
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
Raises:
ValueError: If input key is not an attribute in context.
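A minimal sketch of turning the flag on, matching how the updated tests in this commit configure it:

from mindspore import context

# Enable the sparse/undetermined handling added in this commit; the tests call this
# before building networks whose gradients are IndexedSlices.
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)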

View File

@ -162,8 +162,8 @@ class Adam(Optimizer):
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
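For context, a minimal sketch of this sparse path, adapted from the NetWithSparseGatherV2 definitions in the updated tests later in this diff; the construct body, learning rate, and the final training step are illustrative assumptions:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class NetWithSparseGatherV2(nn.Cell):
    """Forward net in which weight1 receives a sparse (IndexedSlices) gradient."""
    def __init__(self):
        super(NetWithSparseGatherV2, self).__init__()
        self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
        self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
        self.axis = 0
        self.gather = P.SparseGatherV2()

    def construct(self, indices):
        # Gathering rows of weight1 is what makes its gradient sparse.
        return self.gather(self.weight1, indices, self.axis) + self.weight2

net = NetWithSparseGatherV2()
optimizer = Adam(net.trainable_params(), learning_rate=0.1)
train_net = TrainOneStepCell(net, optimizer)
train_net(Tensor(np.array([0, 1]).astype(np.int32)))  # one illustrative training step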
Args:

View File

@ -72,8 +72,8 @@ class FTRL(Optimizer):
<https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
Args:

View File

@ -91,8 +91,8 @@ class LazyAdam(Optimizer):
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
Note that the sparse behavior is not equivalent to the
original Adam algorithm, as only the parameters at the current indices will be updated. The sparse feature is under
continuous development. The sparse behavior is currently performed on the CPU.

View File

@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer):
<http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.
Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
Args:

View File

@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
is_indexed_slices = Primitive('IsIndexedSlices')
tensor_operator_registry.register('__add__', tensor_add)

View File

@ -36,6 +36,8 @@ from mindspore._checkparam import Rel
from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
@ -44,7 +46,6 @@ reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self):
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
if gradient.is_indexed_slices():
return gradient.values()
"Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
m, v, gradient, decay_flag):
return gradient.values()
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
m, v, gradient, decay_flag):
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices():
self.dense_shape = (3, 4)
def construct(self, indices, values):
ret = (IndexedSlices(indices, values, self.dense_shape),)
return ret[0].is_indexed_slices()
return ret[0]
indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
MakeIndexedSlices()(indices, values)
@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
self.network = network
def construct(self, x, y):
grad = grad_all(self.network)(x, y)
return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()
return grad, grad[0], grad[1]
class SparseGatherV2(nn.Cell):
def __init__(self):
super(SparseGatherV2, self).__init__()
@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
weights = self.weights
grad = grad_by_list(self.network, weights)(x)
x = grad[0]
return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()
return x, x.values(), x.indices(), x.dense_shape()
class SparseGatherV2(nn.Cell):
def __init__(self):
super(SparseGatherV2, self).__init__()
self.sparse_gatherv2 = MySparseGatherV2()
self.axis = 0
self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
name="params", has_indexed_slices_grad=True)
self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")
def construct(self, indices):
return self.sparse_gatherv2(self.params, indices, self.axis)
indices = Tensor(np.array([0, 1]).astype(np.int32))
@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
network(indices)
def test_indexed_slices_is_indexed_slices():
class MakeIndexedSlices(nn.Cell):
def __init__(self):
super(MakeIndexedSlices, self).__init__()
self.dense_shape = (3, 4)
def construct(self, indices, values):
indexed_slices = IndexedSlices(indices, values, self.dense_shape)
ret = indexed_slices.is_indexed_slices()
return ret
indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
MakeIndexedSlices()(indices, values)
def test_indexed_slices_env_get():
class Loss(nn.Cell):
def __init__(self):
@ -271,7 +262,7 @@ def test_indexed_slices_env_get():
class NetWithSparseGatherV2(nn.Cell):
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True)
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
self.gatherv2 = MySparseGatherV2()
self.axis = 0

View File

@ -17,12 +17,13 @@ import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell):
""" Net definition """
@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
name="weight1", sparse_grad="sparse_key_w1")
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()

View File

@ -27,6 +27,7 @@ from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
context.set_context(enable_sparse=True)
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse():
class NetWithSparseGatherV2(nn.Cell):
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
self.gatherv2 = P.SparseGatherV2()
self.axis = 0

View File

@ -17,12 +17,13 @@
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import FTRL
from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell):
def __init__(self):
@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
name="weight1", sparse_grad="sparse_key_w1")
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()

View File

@ -17,12 +17,13 @@ import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import LazyAdam
from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell):
""" Net definition """
@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
name="weight1", sparse_grad="sparse_key_w1")
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()

View File

@ -17,12 +17,13 @@
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import ProximalAdagrad
from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell):
def __init__(self):
@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1",
sparse_grad="sparse_key_w1")
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()

View File

@ -53,4 +53,4 @@ def test_hypermap_specialize_param():
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
ret = hypermap_specialize_param()
assert ret == (expected_ret, expected_ret)
assert ret == (expected_ret, list(expected_ret))