forked from mindspore-Ecosystem/mindspore
Apply indexed_slices
This commit is contained in:
parent
e03bd975a9
commit
44e74ad5aa
|
@ -45,7 +45,8 @@ FuncGraph::FuncGraph()
|
|||
hyper_param_count_(0),
|
||||
is_generated_(false),
|
||||
return_(nullptr),
|
||||
manager_(std::weak_ptr<FuncGraphManager>()) {
|
||||
manager_(std::weak_ptr<FuncGraphManager>()),
|
||||
stub_(false) {
|
||||
debug_info_ = std::make_shared<GraphDebugInfo>();
|
||||
}
|
||||
|
||||
|
|
|
@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase {
|
|||
void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
|
||||
bool HasEffect(const CNodePtr &cnode);
|
||||
|
||||
bool stub() const { return stub_; }
|
||||
void set_stub(bool stub) { stub_ = stub; }
|
||||
|
||||
private:
|
||||
// graph is manipulated by manager and others
|
||||
friend FuncGraphManager;
|
||||
|
@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase {
|
|||
|
||||
// CNode order which relates to origin code order
|
||||
std::list<CNodePtr> order_;
|
||||
bool stub_;
|
||||
};
|
||||
|
||||
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
|
||||
|
|
|
@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
|
|||
(*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
|
||||
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
|
||||
(*target_func_graph)->set_is_generate(func_graph->is_generated());
|
||||
(*target_func_graph)->set_stub(func_graph->stub());
|
||||
TraceManager::EndTrace();
|
||||
}
|
||||
|
||||
|
@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
|
|||
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
|
||||
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
|
||||
new_func_graph->set_is_generate(func_graph->is_generated());
|
||||
new_func_graph->set_stub(func_graph->stub());
|
||||
for (auto &item : func_graph->parameter_default_value()) {
|
||||
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "pipeline/static_analysis/param_validator.h"
|
||||
#include "operator/cc_implementations.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "./common.h"
|
||||
|
@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
|
|||
}
|
||||
return item.second;
|
||||
}
|
||||
// Try best match
|
||||
py::function py_fn_subclass;
|
||||
size_t subclass_match_cnt = 0;
|
||||
for (auto &item : fn_cache_py_) {
|
||||
TypePtrList sign = item.first;
|
||||
if (sign.size() != types.size()) {
|
||||
continue;
|
||||
return py::none();
|
||||
}
|
||||
|
||||
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
if (!enable_sparse) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
ParameterPtr undetermined_param = nullptr;
|
||||
auto stub = std::make_shared<FuncGraph>();
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
auto param = stub->add_parameter();
|
||||
parameters.push_back(param);
|
||||
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
|
||||
undetermined_param = param;
|
||||
}
|
||||
auto match = true;
|
||||
for (size_t i = 0; i < sign.size(); ++i) {
|
||||
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
|
||||
!IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
if (undetermined_param != nullptr) {
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
if (types[i]->type_id() == kObjectTypeFunction) {
|
||||
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
|
||||
inputs.push_back(stub->NewCNode(call_prim));
|
||||
} else {
|
||||
inputs.push_back(parameters[i]);
|
||||
}
|
||||
}
|
||||
if (!match) {
|
||||
continue;
|
||||
}
|
||||
py_fn_subclass = item.second;
|
||||
subclass_match_cnt++;
|
||||
auto stub_output = stub->NewCNode(inputs);
|
||||
stub->set_output(stub_output);
|
||||
stub->set_stub(true);
|
||||
return stub;
|
||||
}
|
||||
if (subclass_match_cnt > 1) {
|
||||
MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
|
||||
}
|
||||
if (subclass_match_cnt == 1) {
|
||||
MS_LOG(DEBUG) << "Found one subclass match";
|
||||
return py_fn_subclass;
|
||||
}
|
||||
return py::none();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
||||
|
@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
|||
MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
|
||||
return func_graph;
|
||||
}
|
||||
auto stub = GenerateStubFunc(types);
|
||||
if (stub != nullptr) {
|
||||
MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
|
||||
return stub;
|
||||
}
|
||||
std::ostringstream oss;
|
||||
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
|
||||
<< "`, corresponding location info:\n";
|
||||
|
|
|
@ -23,8 +23,8 @@
|
|||
#include "pipeline/static_analysis/param_validator.h"
|
||||
#include "pipeline/static_analysis/prim.h"
|
||||
#include "pipeline/static_analysis/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
|
|||
return AbstractFunction::MakeAbstractFunction(jv);
|
||||
}
|
||||
|
||||
class UndeterminedShapeType {
|
||||
public:
|
||||
explicit UndeterminedShapeType(const std::string &env_str) {
|
||||
// param_name indices_shape indices_type values_shape values_type dense_shape
|
||||
// export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1
|
||||
// 2:Float32:3 1 2"
|
||||
std::vector<string> fields;
|
||||
string tmp;
|
||||
std::stringstream input(env_str);
|
||||
while (std::getline(input, tmp, ':')) {
|
||||
fields.push_back(tmp);
|
||||
}
|
||||
if (fields.size() != fields_num) {
|
||||
MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
|
||||
}
|
||||
|
||||
param_name_ = fields[0];
|
||||
|
||||
indices_shape_ = GetShape(fields[1]);
|
||||
indices_type_ = StringToType(fields[2]);
|
||||
|
||||
values_shape_ = GetShape(fields[3]);
|
||||
values_type_ = StringToType(fields[4]);
|
||||
|
||||
auto dense_shape_vec = GetShape(fields[5]);
|
||||
AbstractBasePtrList dense_shape_list;
|
||||
(void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
|
||||
[](const auto &elem) { return FromValue(elem, false); });
|
||||
dense_shape_ = dense_shape_list;
|
||||
}
|
||||
~UndeterminedShapeType() = default;
|
||||
const std::string ¶m_name() { return param_name_; }
|
||||
const std::vector<int> &indices_shape() { return indices_shape_; }
|
||||
const TypePtr &indices_type() { return indices_type_; }
|
||||
const std::vector<int> &values_shape() { return values_shape_; }
|
||||
const TypePtr &values_type() { return values_type_; }
|
||||
const AbstractBasePtrList &dense_shape() { return dense_shape_; }
|
||||
|
||||
private:
|
||||
std::string param_name_;
|
||||
std::vector<int> indices_shape_;
|
||||
TypePtr indices_type_;
|
||||
std::vector<int> values_shape_;
|
||||
TypePtr values_type_;
|
||||
AbstractBasePtrList dense_shape_;
|
||||
static const size_t fields_num;
|
||||
|
||||
std::vector<int> GetShape(const std::string &shape_str);
|
||||
};
|
||||
std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
|
||||
std::vector<int> ret;
|
||||
std::istringstream iss(shape_str);
|
||||
int elem;
|
||||
while (iss.good()) {
|
||||
iss >> elem;
|
||||
ret.emplace_back(elem);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
const size_t UndeterminedShapeType::fields_num = 6;
|
||||
|
||||
std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
|
||||
void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
|
||||
std::string tmp;
|
||||
std::stringstream input(sparse_shape_types);
|
||||
g_undetermined_configs.clear();
|
||||
while (std::getline(input, tmp, ';')) {
|
||||
auto config = UndeterminedShapeType(tmp);
|
||||
g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
|
||||
MS_LOG(DEBUG) << "Undetermined config from env: " << tmp;
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
|
||||
}
|
||||
|
||||
if (!key->sparse_grad().empty()) {
|
||||
// Will be fixed once undetermined type ready
|
||||
if (g_undetermined_configs.empty()) {
|
||||
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
|
||||
MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types;
|
||||
if (sparse_shape_types.empty()) {
|
||||
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
|
||||
}
|
||||
InitUndeterminedFromEnv(sparse_shape_types);
|
||||
}
|
||||
|
||||
auto shape_types = g_undetermined_configs.find(key->sparse_grad());
|
||||
if (shape_types == g_undetermined_configs.end()) {
|
||||
MS_LOG(EXCEPTION) << "Param " << key->ToString()
|
||||
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES";
|
||||
}
|
||||
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
|
||||
AbstractBasePtrList sparse_list;
|
||||
// indices
|
||||
auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type());
|
||||
auto indices =
|
||||
std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types->second.indices_shape()));
|
||||
sparse_list.emplace_back(indices);
|
||||
// values
|
||||
auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type());
|
||||
auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape()));
|
||||
sparse_list.emplace_back(dout);
|
||||
// dense_shape
|
||||
sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape()));
|
||||
return std::make_shared<AbstractTuple>(sparse_list);
|
||||
}
|
||||
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
if (enable_sparse && dflt->isa<AbstractTensor>()) {
|
||||
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
|
||||
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
|
||||
}
|
||||
|
||||
if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
|
||||
return dflt;
|
||||
}
|
||||
|
@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
if (type->type_id() != kObjectTypeRefKey) {
|
||||
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
|
||||
}
|
||||
auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||
ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
|
||||
return ret;
|
||||
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
|
|
|
@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
|
|||
}
|
||||
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor {
|
|||
|
||||
// G
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
|
||||
return nullptr;
|
||||
}
|
||||
// Do not inline GraphKernel to Cell.
|
||||
|
|
|
@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|||
std::string env = common::GetEnv("SLICE_ENV");
|
||||
if (!env.empty()) {
|
||||
MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
|
||||
abstract::InitUndeterminedFromEnv(env);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
ValuePtr value = param_value->value();
|
||||
constexpr bool broaden = true;
|
||||
AbstractBasePtr ptr = abstract::FromValue(value, broaden);
|
||||
ptr->set_sparse_grad(param_value->sparse_grad());
|
||||
ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad());
|
||||
|
||||
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
|
||||
args_spec.push_back(ptr);
|
||||
|
|
|
@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
|
||||
"Set the GraphKernel switch to on or off.")
|
||||
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
|
||||
.def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
|
||||
.def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
|
||||
.def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
|
||||
.def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
|
||||
{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"cconv", CconvPass},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||
{"add_control_depend", AddControlDependPass}};
|
||||
|
||||
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"add_control_depend", AddControlDependPass},
|
||||
{"opt_control", ControlGroup},
|
||||
{"opt_prepare", PrepareGroup},
|
||||
{"cconv", CconvPass}};
|
||||
std::vector<PassItem> kGePasses = {
|
||||
{"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass},
|
||||
{"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
|
||||
{"cconv", CconvPass}};
|
||||
|
||||
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
|
||||
} // namespace pipeline
|
||||
|
|
|
@ -146,37 +146,35 @@ MethodMap &GetMethodMap() {
|
|||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
}},
|
||||
{kObjectTypeIndexedSlicesType,
|
||||
{
|
||||
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
|
||||
{"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values
|
||||
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
|
||||
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
|
||||
|
|
|
@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const {
|
|||
AbstractBasePtr AbstractBase::Broaden() const {
|
||||
AbstractBasePtr clone = Clone();
|
||||
clone->set_value(kAnyValue);
|
||||
clone->set_sparse_grad(sparse_grad_);
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const {
|
|||
MS_EXCEPTION_IF_NULL(type_);
|
||||
MS_EXCEPTION_IF_NULL(shape_);
|
||||
buffer << type_name() << "("
|
||||
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
|
||||
<< " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
|
||||
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
|
|||
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
if (*this == *other) {
|
||||
auto ret = shared_from_base<AbstractBase>();
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
auto value_self = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_self);
|
||||
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
|
||||
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
|
||||
if (res_value == value_self) {
|
||||
auto ret = shared_from_base<AbstractBase>();
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
return std::make_shared<AbstractScalar>(res_value, res_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractType::Clone() const {
|
||||
|
@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
|||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
}
|
||||
if (*this == *other) {
|
||||
if (sparse_grad() == other->sparse_grad()) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
auto element = element_->Join(other_tensor->element_);
|
||||
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
|
||||
auto ret = std::make_shared<AbstractTensor>(element, shape);
|
||||
ret->set_sparse_grad(sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return ret;
|
||||
return std::make_shared<AbstractTensor>(element, shape);
|
||||
}
|
||||
|
||||
bool AbstractTensor::operator==(const AbstractTensor &other) const {
|
||||
|
@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const {
|
|||
ShapePtr shp = shape();
|
||||
clone->set_shape(shp->Clone());
|
||||
clone->set_value(GetValueTrack());
|
||||
clone->set_sparse_grad(sparse_grad());
|
||||
clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const {
|
|||
auto shp = shape();
|
||||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_sparse_grad(sparse_grad());
|
||||
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
|
@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
|
|||
shp->Broaden();
|
||||
broaden->set_shape(shp);
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_sparse_grad(sparse_grad());
|
||||
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
|
@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const {
|
|||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
buffer << type_name() << "("
|
||||
<< "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
|
||||
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
|
||||
<< " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
|
||||
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ class AbstractBase : public Base {
|
|||
public:
|
||||
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
|
||||
const BaseShapePtr &shape = kNoShape)
|
||||
: value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
|
||||
: value_(value), type_(type), shape_(shape) {}
|
||||
~AbstractBase() override = default;
|
||||
MS_DECLARE_PARENT(AbstractBase, Base)
|
||||
|
||||
|
@ -53,17 +53,11 @@ class AbstractBase : public Base {
|
|||
|
||||
virtual bool operator==(const AbstractBase &other) const;
|
||||
void set_value(const ValuePtr &value) { value_ = value; }
|
||||
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
|
||||
void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
|
||||
has_indexed_slices_grad_ = has_indexed_slices_grad;
|
||||
}
|
||||
void set_type(const TypePtr &type) { type_ = type; }
|
||||
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
|
||||
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
|
||||
const std::string &value_desc() const { return value_desc_; }
|
||||
ValuePtr GetValueTrack() const { return value_; }
|
||||
const std::string &sparse_grad() const { return sparse_grad_; }
|
||||
const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
|
||||
TypePtr GetTypeTrack() const { return type_; }
|
||||
BaseShapePtr GetShapeTrack() const { return shape_; }
|
||||
|
||||
|
@ -91,8 +85,6 @@ class AbstractBase : public Base {
|
|||
TypePtr type_;
|
||||
BaseShapePtr shape_;
|
||||
std::string value_desc_; // store initial value description for error report
|
||||
std::string sparse_grad_;
|
||||
bool has_indexed_slices_grad_;
|
||||
};
|
||||
|
||||
class AbstractScalar : public AbstractBase {
|
||||
|
|
|
@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(ret_base);
|
||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
|
||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
|
||||
<< ", is stub: " << fg->stub();
|
||||
if (fg->stub()) {
|
||||
return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
|
||||
}
|
||||
return std::make_shared<EvalResult>(ret_base, nullptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "pipeline/static_analysis/static_analysis.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -59,6 +60,13 @@ class Evaluator : public Base {
|
|||
}
|
||||
|
||||
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
if (!enable_sparse) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
|
||||
if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
return true;
|
||||
|
|
|
@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
using mindspore::parse::PyObjectWrapper;
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
|
||||
if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
||||
|
@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
|
||||
auto ret_abstract = AbstractEval(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
}
|
||||
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
||||
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
|
||||
|
||||
ScopePtr scope = kDefaultScope;
|
||||
if (out_conf != nullptr) {
|
||||
scope = out_conf->node()->scope();
|
||||
|
@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
|||
} // end anonymous namespace
|
||||
|
||||
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
|
||||
|
||||
|
@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
|
|||
}
|
||||
|
||||
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
|
||||
if (nargs_ != args.size()) {
|
||||
|
@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
auto ret = std::make_shared<AbstractScalar>(type);
|
||||
auto ref_value = ref_abs->ref();
|
||||
MS_EXCEPTION_IF_NULL(ref_value);
|
||||
ret->set_sparse_grad(ref_value->sparse_grad());
|
||||
ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
|
||||
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
|
@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
x = SensitivityTransform(x);
|
||||
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
|
||||
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
|
||||
abs_scalar->set_sparse_grad(x->sparse_grad());
|
||||
abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
|
||||
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
};
|
||||
|
@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
|||
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse_flag = context->enable_sparse_flag();
|
||||
if (enable_sparse_flag) {
|
||||
auto ret_abstract = AbstractEval(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
auto ret_abstract = AbstractEval(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
// Inputs: data, item
|
||||
if (args_spec_list.size() != 2) {
|
||||
|
|
|
@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
|
|||
|
||||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
|
||||
|
||||
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
|
|
@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
|
|||
AbstractFunctionPtr func = real_a->GetUnique();
|
||||
SpecializeStatusCode errcode;
|
||||
ScopeGuard scope_guard(node->scope());
|
||||
AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
|
||||
AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
|
||||
if (repl == nullptr) {
|
||||
if (errcode == kSpecializeFindUniqueArgvalDead) {
|
||||
const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
|
||||
|
@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
|
|||
return repl;
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
|
||||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
|
||||
const AbstractFunctionPtr &func,
|
||||
const AbstractBasePtrList &args,
|
||||
SpecializeStatusCode *errcode) {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
|
@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
|
|||
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
|
||||
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
|
||||
<< ", graph: " << context->func_graph()->get_return()->DebugString();
|
||||
if (context->func_graph()->stub()) {
|
||||
MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
|
||||
<< ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
|
||||
<< ", " << node->ToString();
|
||||
return node;
|
||||
}
|
||||
FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
|
||||
v->set_flag(kFuncGraphFlagUndetermined, false);
|
||||
return BuildValueNode(v, abs);
|
||||
}
|
||||
|
||||
|
@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices->empty()) {
|
||||
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
|
||||
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
|
||||
<< func->type_name();
|
||||
return kSpecializeFindUniqueArgvalDead;
|
||||
} else {
|
||||
if (IsPolyFunc(func, argvals)) {
|
||||
|
|
|
@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
// Build a specialized node from given argvals;
|
||||
AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
|
||||
const AbstractBasePtrList &argvals);
|
||||
AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
|
||||
const AbstractBasePtrList &args, SpecializeStatusCode *errcode);
|
||||
AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
|
||||
const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
|
||||
SpecializeStatusCode *errcode);
|
||||
|
||||
// Find the unique argument values which can be used to specialize a primitive or graph function.
|
||||
SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,
|
||||
|
|
|
@ -89,7 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
max_device_memory_ = kDefaultMaxDeviceMemory;
|
||||
print_file_path_ = "";
|
||||
enable_graph_kernel_ = false;
|
||||
enable_sparse_flag_ = false;
|
||||
enable_sparse_ = false;
|
||||
}
|
||||
|
||||
std::shared_ptr<MsContext> MsContext::GetInstance() {
|
||||
|
|
|
@ -161,8 +161,8 @@ class MsContext {
|
|||
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
|
||||
bool enable_graph_kernel() const { return enable_graph_kernel_; }
|
||||
|
||||
bool enable_sparse_flag() const { return enable_sparse_flag_; }
|
||||
void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
|
||||
bool enable_sparse() const { return enable_sparse_; }
|
||||
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
|
||||
|
||||
private:
|
||||
MsContext(const std::string &backend_policy, const std::string &target);
|
||||
|
@ -207,7 +207,7 @@ class MsContext {
|
|||
float max_device_memory_;
|
||||
std::string print_file_path_;
|
||||
bool enable_graph_kernel_;
|
||||
bool enable_sparse_flag_;
|
||||
bool enable_sparse_;
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,18 +51,13 @@ class Parameter:
|
|||
requires_grad (bool): True if the parameter requires gradient. Default: True.
|
||||
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
|
||||
broadcast and gradients communication would not be applied on parameters. Default: False.
|
||||
sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
|
||||
has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false.
|
||||
"""
|
||||
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
|
||||
sparse_grad="", has_indexed_slices_grad=False):
|
||||
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
|
||||
self._value = ParamValue()
|
||||
self.set_parameter_data(default_input)
|
||||
self.name = name
|
||||
self.requires_grad = requires_grad
|
||||
self.layerwise_parallel = layerwise_parallel
|
||||
self.sparse_grad = sparse_grad
|
||||
self.has_indexed_slices_grad = has_indexed_slices_grad
|
||||
self._is_init = False
|
||||
self._sliced = False
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
|
@ -177,28 +172,6 @@ class Parameter:
|
|||
raise TypeError("`requires_grad` parameter must be bool type")
|
||||
self._value.requires_grad = value
|
||||
|
||||
@property
|
||||
def sparse_grad(self):
|
||||
"""Return whether the parameter's gradient is sparse."""
|
||||
return self._value.sparse_grad
|
||||
|
||||
@sparse_grad.setter
|
||||
def sparse_grad(self, value=""):
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("`sparse_grad` parameter must be str type")
|
||||
self._value.sparse_grad = value
|
||||
|
||||
@property
|
||||
def has_indexed_slices_grad(self):
|
||||
"""Return whether the parameter's gradient is indexed_slices."""
|
||||
return self._value.has_indexed_slices_grad
|
||||
|
||||
@has_indexed_slices_grad.setter
|
||||
def has_indexed_slices_grad(self, value=False):
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
|
||||
self._value.has_indexed_slices_grad = value
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.default_input
|
||||
|
|
|
@ -367,14 +367,6 @@ class _Context:
|
|||
def check_bprop(self, check_bprop_flag):
|
||||
self._context_handle.set_check_bprop_flag(check_bprop_flag)
|
||||
|
||||
@property
|
||||
def enable_sparse(self):
|
||||
return self._context_handle.get_enable_sparse_flag()
|
||||
|
||||
@enable_sparse.setter
|
||||
def enable_sparse(self, enable_sparse_flag):
|
||||
self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
|
||||
|
||||
@property
|
||||
def max_device_memory(self):
|
||||
return self._context_handle.get_max_device_memory()
|
||||
|
@ -408,6 +400,13 @@ class _Context:
|
|||
full_file_name = print_file_path
|
||||
self._context_handle.set_print_file_path(full_file_name)
|
||||
|
||||
@property
|
||||
def enable_sparse(self):
|
||||
return self._context_handle.get_enable_sparse()
|
||||
|
||||
@enable_sparse.setter
|
||||
def enable_sparse(self, enable_sparse):
|
||||
self._context_handle.set_enable_sparse(enable_sparse)
|
||||
|
||||
def check_input_format(x):
|
||||
import re
|
||||
|
@ -601,7 +600,7 @@ def set_context(**kwargs):
|
|||
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
|
||||
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
|
||||
suffix to the file.
|
||||
enable_sparse (bool): Whether to enable sparse feature. Default: False.
|
||||
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not an attribute in context.
|
||||
|
|
|
@ -162,8 +162,8 @@ class Adam(Optimizer):
|
|||
|
||||
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
|
||||
The sparse feature is under continuous development. The sparse
|
||||
behavior is currently performed on the CPU.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -72,8 +72,8 @@ class FTRL(Optimizer):
|
|||
<https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
|
||||
|
||||
Note:
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
|
||||
The sparse feature is under continuous development. The sparse
|
||||
behavior is currently performed on the CPU.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -91,8 +91,8 @@ class LazyAdam(Optimizer):
|
|||
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
|
||||
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
|
||||
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
|
||||
The sparse behavior, to be notice, is not equivalent to the
|
||||
original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
|
||||
continuous development. The sparse behavior is currently performed on the CPU.
|
||||
|
||||
|
|
|
@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer):
|
|||
<http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.
|
||||
|
||||
Note:
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
|
||||
The sparse feature is under continuous development. The sparse
|
||||
behavior is currently performed on the CPU.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices')
|
|||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||
is_indexed_slices = Primitive('IsIndexedSlices')
|
||||
|
||||
|
||||
tensor_operator_registry.register('__add__', tensor_add)
|
||||
|
|
|
@ -36,6 +36,8 @@ from mindspore._checkparam import Rel
|
|||
from mindspore.nn import Optimizer
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||
|
||||
reduce_sum = P.ReduceSum()
|
||||
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||
transpose = P.Transpose()
|
||||
|
@ -44,7 +46,6 @@ reshape = P.Reshape()
|
|||
size_op = P.Size()
|
||||
invert_permutation = P.InvertPermutation()
|
||||
logical_and = P.LogicalAnd()
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||
|
||||
@constexpr
|
||||
def _generate_shape_index(out_shape, indices_shape, axis):
|
||||
|
@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self):
|
|||
|
||||
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
|
||||
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
|
||||
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||
if gradient.is_indexed_slices():
|
||||
return gradient.values()
|
||||
"Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
|
||||
def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
|
||||
m, v, gradient, decay_flag):
|
||||
return gradient.values()
|
||||
|
||||
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
|
||||
m, v, gradient, decay_flag):
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
|
@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices():
|
|||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
||||
return ret[0].is_indexed_slices()
|
||||
return ret[0]
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
MakeIndexedSlices()(indices, values)
|
||||
|
@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
|
|||
self.network = network
|
||||
def construct(self, x, y):
|
||||
grad = grad_all(self.network)(x, y)
|
||||
return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()
|
||||
return grad, grad[0], grad[1]
|
||||
class SparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseGatherV2, self).__init__()
|
||||
|
@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
|||
weights = self.weights
|
||||
grad = grad_by_list(self.network, weights)(x)
|
||||
x = grad[0]
|
||||
return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()
|
||||
return x, x.values(), x.indices(), x.dense_shape()
|
||||
class SparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseGatherV2, self).__init__()
|
||||
self.sparse_gatherv2 = MySparseGatherV2()
|
||||
self.axis = 0
|
||||
self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
|
||||
name="params", has_indexed_slices_grad=True)
|
||||
self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")
|
||||
def construct(self, indices):
|
||||
return self.sparse_gatherv2(self.params, indices, self.axis)
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
|
@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
|||
network(indices)
|
||||
|
||||
|
||||
def test_indexed_slices_is_indexed_slices():
|
||||
class MakeIndexedSlices(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MakeIndexedSlices, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
indexed_slices = IndexedSlices(indices, values, self.dense_shape)
|
||||
ret = indexed_slices.is_indexed_slices()
|
||||
return ret
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
MakeIndexedSlices()(indices, values)
|
||||
|
||||
|
||||
def test_indexed_slices_env_get():
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -271,7 +262,7 @@ def test_indexed_slices_env_get():
|
|||
class NetWithSparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True)
|
||||
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
|
||||
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
|
||||
self.gatherv2 = MySparseGatherV2()
|
||||
self.axis = 0
|
||||
|
|
|
@ -17,12 +17,13 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor, Parameter, context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell):
|
|||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
|
||||
name="weight1", sparse_grad="sparse_key_w1")
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
|
|
@ -27,6 +27,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
|
||||
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
|
@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse():
|
|||
class NetWithSparseGatherV2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
|
||||
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
|
||||
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
self.axis = 0
|
||||
|
|
|
@ -17,12 +17,13 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor, Parameter, context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import FTRL
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell):
|
|||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
|
||||
name="weight1", sparse_grad="sparse_key_w1")
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
|
|
@ -17,12 +17,13 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor, Parameter, context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import LazyAdam
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell):
|
|||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
|
||||
name="weight1", sparse_grad="sparse_key_w1")
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
|
|
@ -17,12 +17,13 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor, Parameter, context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import ProximalAdagrad
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell):
|
|||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1",
|
||||
sparse_grad="sparse_key_w1")
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
|
|
@ -53,4 +53,4 @@ def test_hypermap_specialize_param():
|
|||
|
||||
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
|
||||
ret = hypermap_specialize_param()
|
||||
assert ret == (expected_ret, expected_ret)
|
||||
assert ret == (expected_ret, list(expected_ret))
|
||||
|
|
Loading…
Reference in New Issue