Add 'bprop_return_sparse' flag for primitives whose bprop returns sparse results, and remove the context option 'enable_sparse'.

Zhang Qinghua 2022-05-17 09:26:41 +08:00
parent da2a22dc22
commit d7762f1c8d
58 changed files with 157 additions and 194 deletions
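
In short: instead of the global 'enable_sparse' context switch, a primitive now opts in by carrying the 'bprop_return_sparse' attribute, which the compiler picks up and propagates as FUNC_GRAPH_FLAG_SPARSE_BPROP. A minimal before/after sketch of the user-visible change, assuming a hypothetical custom primitive (the pattern mirrors the SparseGatherV2 and EmbeddingLookup changes below):

# Before this commit, sparse bprop handling was gated globally:
#     context.set_context(enable_sparse=True)   # option removed by this commit
# After it, the primitive itself declares that its bprop returns a
# sparse result. MySparseGather is hypothetical, for illustration only.
from mindspore.ops.primitive import PrimitiveWithCheck, prim_attr_register

class MySparseGather(PrimitiveWithCheck):
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
        self.add_prim_attr('bprop_return_sparse', True)  # the new flag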

View File

@@ -59,6 +59,7 @@ using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using mindspore::abstract::AbstractUndetermined;
using mindspore::abstract::EnvSetSparseResultMgr;
using mindspore::abstract::FuncGraphAbstractClosure;
void HyperMap::Init() {
@@ -645,19 +646,12 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
}
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &weight_args) {
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
AnfNodePtr weights_node = nullptr;
AnfNodePtr position_node = nullptr;
if (weights != nullptr) {
weights_node = weights;
} else if (!weight_args.empty()) {
weights_node = k_child->NewCNodeInOrder(weight_args);
}
if (position != nullptr) {
position_node = position;
}
@@ -673,7 +667,7 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weigh
auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
GradByParameter(k_child, f_app, bprop, weights_node, position_node, enable_tuple_grad);
GradByParameter(k_child, f_app, bprop, weights, position_node, enable_tuple_grad);
return k_child;
}
@@ -740,6 +734,38 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
}
namespace {
// Check if the primal func graph contains a primitive whose bprop() returns a sparse result.
void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
bool has_sparse_bprop_prim = false;
(void)TopoSort(primal_graph->return_node(), SuccDeeperSimple,
[&has_sparse_bprop_prim](const AnfNodePtr &node) -> IncludeType {
MS_EXCEPTION_IF_NULL(node);
if (has_sparse_bprop_prim) {
return EXCLUDE;
}
auto prim = GetCNodePrimitive(node);
if (prim != nullptr) {
auto do_signature = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_signature != nullptr) {
prim = dyn_cast<Primitive>(do_signature->function());
}
bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
if (sparse_bprop) {
MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";
has_sparse_bprop_prim = true;
return EXCLUDE;
}
}
return FOLLOW;
});
if (has_sparse_bprop_prim) {
primal_graph->set_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP, true);
EnvSetSparseResultMgr::GetInstance().Set(true);
}
}
} // namespace
// Generate the graph.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.empty()) {
@@ -761,6 +787,10 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
// Check if the primal func graph contains a primitive whose bprop() returns a sparse result.
CheckPrimBpropReturnSparse(forward_graph);
FuncGraphPtr grad_fg = nullptr;
{
TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
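
From the Python side this detection is transparent: GenerateFuncGraph() above runs CheckPrimBpropReturnSparse on the forward graph before building the grad graph, so taking a gradient of a network that uses a flagged primitive yields sparse gradients without any context setting. A hedged usage sketch (standard GradOperation pattern; not part of this diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple, context
from mindspore.ops import operations as P
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE)  # no enable_sparse needed any more

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.gather = P.SparseGatherV2()  # carries 'bprop_return_sparse' after this commit
        self.weight = Parameter(Tensor(np.ones([8, 2]).astype(np.float32)), name="weight")
        self.axis = 0

    def construct(self, indices):
        return self.gather(self.weight, indices, self.axis)

net = Net()
params = ParameterTuple(net.trainable_params())
indices = Tensor(np.array([0, 1, 2]).astype(np.int32))
# Finding the flagged primitive sets FUNC_GRAPH_FLAG_SPARSE_BPROP on the
# forward graph; the weight gradient then comes back as a RowTensor.
grads = C.GradOperation(get_by_list=True)(net, params)(indices)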

View File

@@ -150,8 +150,7 @@ class GradOperation : public MetaFuncGraph {
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad_first,
const std::vector<AnfNodePtr> &weight_args = {});
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad_first);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;

View File

@@ -353,6 +353,10 @@ AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePt
bool result = false;
if (!CheckCmpValid(cmp)) {
auto cmp_type = cmp->BuildType();
MS_EXCEPTION_IF_NULL(cmp_type);
MS_LOG(ERROR) << "cmp: " << cmp->ToString() << ", cmp_type: " << cmp_type->ToString()
<< ", cmp_type_id: " << TypeIdToType(cmp_type->type_id());
MS_EXCEPTION(TypeError) << "isinstance() arg 2 must be a type or tuple of types.";
}

View File

@@ -161,9 +161,7 @@ Status GatherInfo::GetAttrs() {
dynamic_shape_indices_ = true;
}
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
dynamic_shape_indices_ = true;
}
#endif

View File

@@ -602,16 +602,19 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
// parameters. (sense_f, sense_x, ...)(*bprop_f) (sense_y)
AbstractBasePtrList bparams;
bparams.push_back(SensitivityTransform(primal_func_));
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
// Check if the primal func graph contains a primitive whose bprop() returns a sparse result.
auto real_primal_func = dyn_cast<FuncGraphAbstractClosure>(primal_func_);
MS_EXCEPTION_IF_NULL(real_primal_func);
FuncGraphPtr primal_func_graph = real_primal_func->func_graph();
MS_EXCEPTION_IF_NULL(primal_func_graph);
bool has_sparse_bprop_prim = primal_func_graph->has_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP);
(void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(bparams),
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg_spec);
if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
[&has_sparse_bprop_prim](const AbstractBasePtr &arg_abs) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg_abs);
if (has_sparse_bprop_prim && arg_abs->isa<AbstractTensor>()) {
return std::make_shared<AbstractUndetermined>();
}
return SensitivityTransform(arg_spec);
return SensitivityTransform(arg_abs);
});
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
AbstractFunctionPtr bprop =

View File

@@ -66,21 +66,14 @@ class Evaluator : public Base {
}
virtual EvalResultPtr EvalUndeterminedArgs(const AbstractBasePtrList &args_abs_list) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (!enable_sparse) {
return nullptr;
}
auto is_abstract = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](auto &arg) -> bool {
auto is_undetermined = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](auto &arg) -> bool {
if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
return true;
}
return false;
});
if (is_abstract) {
MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
if (is_undetermined) {
MS_LOG(DEBUG) << "Eval " << identifier_ << " return undetermined abstract result";
return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
}
return nullptr;

View File

@@ -876,13 +876,6 @@ AbstractBasePtr MakePyInferRes2Abstract(const py::object &output) {
// Return monad abstract if it is monad type.
return ToMonadAbstract(type_obj);
} else {
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse) {
return std::make_shared<abstract::AbstractUndetermined>();
}
MS_LOG(EXCEPTION) << "Python evaluator returned invalid shape or type. " << (std::string)py::str(type_obj);
}
}

View File

@@ -381,8 +381,10 @@ void AnalysisEngine::ClearEvaluatorCache() {
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
evaluator->evaluator_cache_mgr()->Clear();
}
// Release Exception to avoid hup at exit.
// Release exception to avoid hang-up at exit.
StaticAnalysisException::Instance().ClearException();
// Reset the EnvironGet sparse option.
EnvSetSparseResultMgr::GetInstance().Set(false);
}
void AnalysisEngine::Clear() {

View File

@@ -78,7 +78,6 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)

View File

@@ -1909,9 +1909,7 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph &g
} else if (cnode->isa<CNode>() && (common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
FinalizePsCache(
"All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
"the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
"context setting in parameter server training mode.");
"All EmbeddingLookup kernels whose input indices are from dataset must enable cache at the same time.");
}
}
}

View File

@@ -34,6 +34,7 @@
#include "ir/value.h"
#include "ir/tensor.h"
#include "abstract/dshape.h"
#include "abstract/utils.h"
#include "utils/shape_utils.h"
namespace mindspore {

View File

@@ -117,10 +117,7 @@ AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePt
}
MS_LOG(DEBUG) << "key: " << key->ToString() << ", value: " << default_value->ToString();
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse && default_value->isa<AbstractTensor>()) {
if (default_value->isa<AbstractTensor>() && EnvSetSparseResultMgr::GetInstance().Get()) {
auto tensor_value = default_value->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(tensor_value);
return std::make_shared<AbstractUndetermined>(tensor_value->element()->Clone(), tensor_value->shape()->Clone());
@@ -149,7 +146,7 @@ AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePt
// args: Three objects of a subclass of AbstractBase, env, key, value.
CheckArgsSize(primitive->name(), args_spec_list, kSizeThree);
auto key = args_spec_list[1];
auto key = args_spec_list[kIndexOne];
ValuePtr key_value_ptr = key->GetValueTrack();
MS_EXCEPTION_IF_NULL(key_value_ptr);
auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
@@ -160,7 +157,11 @@ AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePt
auto expected = key_value_track->abstract();
MS_EXCEPTION_IF_NULL(expected);
MS_LOG(DEBUG) << "key: " << key->ToString() << ", value: " << args_spec_list[kIndexTwo]->ToString();
auto value = args_spec_list[kIndexTwo];
MS_LOG(DEBUG) << "key: " << key->ToString() << ", value: " << value->ToString();
if (value->isa<AbstractUndetermined>() && !value->isa<AbstractTensor>()) {
EnvSetSparseResultMgr::GetInstance().Set(true);
}
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}

View File

@@ -307,14 +307,7 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
// Return monad abstract if it is monad type.
return MakeMonadAbstract(type->cast<MonadTypePtr>());
} else {
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse) {
return std::make_shared<abstract::AbstractUndetermined>();
}
MS_LOG(EXCEPTION) << "evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
MS_LOG(EXCEPTION) << "Evaluator returned invalid shape " << base_shape->ToString() << " or type. " << type->ToString();
}
}
} // namespace abstract

View File

@@ -58,6 +58,24 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
class MS_CORE_API EnvSetSparseResultMgr {
public:
static EnvSetSparseResultMgr &GetInstance() noexcept {
static EnvSetSparseResultMgr instance;
return instance;
}
EnvSetSparseResultMgr(const EnvSetSparseResultMgr &) = delete;
EnvSetSparseResultMgr &operator=(const EnvSetSparseResultMgr &) = delete;
~EnvSetSparseResultMgr() = default;
bool Get() const { return env_set_sparse_result_; }
void Set(bool env_set_sparse_result) { env_set_sparse_result_ = env_set_sparse_result; }
private:
EnvSetSparseResultMgr() = default;
bool env_set_sparse_result_{false};
};
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_

View File

@@ -82,6 +82,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
const char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core";

View File

@@ -163,8 +163,8 @@ std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root) {
return todo;
}
// PushSuccessors push cnode inputs to a vector as successors for topo sort.
static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
// Fetch CNode inputs into a vector as successors for TopoSort().
static void FetchCNodeSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
auto &inputs = cnode->inputs();
vecs->reserve(vecs->size() + inputs.size());
@@ -194,7 +194,7 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
return vecs;
} else if (node->func_graph() != nullptr) {
if (node->isa<CNode>()) {
PushSuccessors(node->cast<CNodePtr>(), &vecs);
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}
@@ -217,7 +217,7 @@ std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {
return vecs;
} else {
if (node->isa<CNode>()) {
PushSuccessors(node->cast<CNodePtr>(), &vecs);
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}
@@ -227,7 +227,7 @@ std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
std::vector<AnfNodePtr> vecs;
auto cnode = dyn_cast<CNode>(node);
if (cnode != nullptr) {
PushSuccessors(cnode, &vecs);
FetchCNodeSuccessors(cnode, &vecs);
}
return vecs;
}
@@ -251,7 +251,7 @@ std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &
}
}
}
PushSuccessors(cnode, &vecs);
FetchCNodeSuccessors(cnode, &vecs);
}
return vecs;
}
@@ -275,7 +275,7 @@ std::vector<AnfNodePtr> SuccWithFilter(const GraphFilterFunc &graph_filter, cons
return vecs;
} else {
if (node->isa<CNode>()) {
PushSuccessors(node->cast<CNodePtr>(), &vecs);
FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
}
return vecs;
}

View File

@@ -28,13 +28,6 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
}
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();

View File

@@ -31,6 +31,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse";
// method names of python primitive called from c++ source code
// 1. infer method name of class 'PrimitiveWithInfer'

View File

@@ -92,7 +92,6 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE, kDefaultMempoolBlockSize);
set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);

View File

@@ -77,7 +77,6 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_PYNATIVE_HOOK,
MS_CTX_ENABLE_PYNATIVE_INFER,
MS_CTX_ENABLE_REDUCE_PRECISION,
MS_CTX_ENABLE_SPARSE,
MS_CTX_ENABLE_TASK_SINK,
MS_CTX_IR_FUSION_FLAG,
MS_CTX_IS_MULTI_GRAPH_SINK,

View File

@@ -642,9 +642,9 @@ def _check_target_specific_cfgs(device, arg_key):
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int,
load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -704,8 +704,6 @@ def set_context(**kwargs):
| +------------------------------+----------------------------+
| | max_call_depth | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | enable_sparse | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | grad_for_scalar | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | enable_compile_cache | CPU/GPU/Ascend |
@@ -827,9 +825,6 @@ def set_context(**kwargs):
The max_call_depth parameter needs to be set when the nested call is too deep or the number
of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should be
set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow.
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
For details of sparsity and sparse tensor, please check
`sparse tensor <https://www.mindspore.cn/tutorials/en/master/beginner/tensor.html#sparse-tensor>`_.
grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
When grad_for_scalar is set to True, the function's scalar input can be derived.
The default value is False. Because the back-end does not support scaling operations currently,
@@ -867,7 +862,6 @@ def set_context(**kwargs):
>>> set_context(max_device_memory="3.5GB")
>>> set_context(mempool_block_size="1GB")
>>> set_context(print_file_path="print.pb")
>>> set_context(enable_sparse=True)
>>> set_context(max_call_depth=80)
>>> set_context(env_config_path="./env_config.json")
>>> set_context(auto_tune_mode="GA,RL")
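
Migration for callers is a one-line deletion; an unrecognized keyword such as the removed 'enable_sparse' should now be rejected by set_context's keyword validation. A minimal sketch:

from mindspore import context

# Before: context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
# After:  just drop the removed keyword.
context.set_context(mode=context.GRAPH_MODE)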

View File

@@ -25,7 +25,6 @@ from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
from mindspore import context
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
@@ -339,12 +338,6 @@ class EmbeddingLookup(Cell):
self.cache_enable = True
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
if context.get_context("enable_sparse") != self.sparse:
raise ValueError(f"For '{self.cls_name}', the value of parameter 'sparse' must be same for all "
f"kernels and equal the value of 'enable_sparse' in context setting in "
f"parameter server cache mode, but got value of parameter 'sparse': {self.sparse}"
f" and the 'enable_sparse' in context setting: "
f"{context.get_context('enable_sparse')}.")
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""

View File

@@ -898,10 +898,6 @@ class EmbeddingLookupThor(Cell):
self.cache_enable = True
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
if context.get_context("enable_sparse") != self.sparse:
raise ValueError(f"For '{self.cls_name}', the 'sparse' must be equal to the 'enable_sparse' "
f"in context setting in parameter server cache mode, but got 'sparse': "
f"{self.sparse}, 'enable_sparse': {context.get_context('enable_sparse')}.")
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""

View File

@@ -26,7 +26,6 @@ from ..functional import broadcast_gradient_args
from .. import functional as F
from .grad_base import bprop_getters
from ..primitive import constexpr
from ... import context
from ...common import dtype as mstype
from ...common.tensor import RowTensor
from .._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown
@@ -116,20 +115,10 @@ def dout_cast_row_tensor(dout, x):
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
"""Generate bprop for Cast"""
cast = P.Cast()
get_dtype = P.DType()
def bprop(x, t, out, dout):
dx = cast(dout, get_dtype(x))
return dx, zeros_like(t)
def bprop_sparse(x, t, out, dout):
dx = dout_cast(dout, x)
return dx, zeros_like(t)
if context.get_context('enable_sparse'):
return bprop_sparse
return bprop

View File

(Binary bprop MindIR diff: bprop.12 re-serialized; the embedded func graph hash changed from @95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938 to @231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44c.)

View File

(Binary bprop MindIR diff: bprop.15, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.18, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.24, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.1, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.6, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.3, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.27, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.30, which uses S-Prim-Select, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.21, same func graph hash update as above.)

View File

(Binary bprop MindIR diff: bprop.9, same func graph hash update as above.)

View File

@@ -913,6 +913,7 @@ class SparseGatherV2(PrimitiveWithCheck):
def __init__(self):
"""Initialize SparseGatherV2"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr('bprop_return_sparse', True)
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
@@ -6291,6 +6292,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
outputs=['output'])
self.add_prim_attr('bprop_return_sparse', True)
def __check__(self, params, indices, offset):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
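
For reference, a primitive flagged this way pairs with a bprop that actually builds a RowTensor. Below is a simplified, hedged reconstruction of such a bprop for the axis-0 case only; the registered SparseGatherV2 bprop in ops/_grad/grad_array_ops.py is the authoritative version:

from mindspore.common.tensor import RowTensor
from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like

reshape = P.Reshape()
shape_op = P.Shape()
size_op = P.Size()

# The real version is registered with @bprop_getters.register(P.SparseGatherV2);
# this sketch is left unregistered and handles axis == 0 only.
def get_bprop_sparse_gather_sketch(self):
    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        indices_size = (size_op(indices),)
        # Rows of dout correspond one-to-one to the gathered indices.
        values = reshape(dout, indices_size + x_shp[1:])
        indices_1d = reshape(indices, indices_size)
        dx = RowTensor(indices_1d, values, x_shp)  # sparse gradient w.r.t. x
        return dx, zeros_like(indices), zeros_like(axis)
    return bprop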

View File

@@ -244,8 +244,6 @@ class ParallelMultiHotFactory:
self.opt = Adam(params=net.get_parameters())
if self.target == 'CPU':
self.opt.target = self.target
if self.sparse:
context.set_context(enable_sparse=True)
self.model = Model(network=net,
loss_fn=self.loss_fn,
optimizer=self.opt)

View File

@@ -19,8 +19,7 @@ from mindspore import Tensor, context
from mindspore.nn import TrainOneStepCell, WithLossCell
context.set_context(enable_sparse=True,
mode=context.GRAPH_MODE)
context.set_context(mode=context.GRAPH_MODE)
class NetWithEmbeddingLookUp(nn.Cell):

View File

@@ -20,10 +20,10 @@ from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import FTRL, LazyAdam
from mindspore.ops import operations as P
context.set_context(enable_sparse=True,
mode=context.GRAPH_MODE,
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend")
class NetWithSparseGatherV2(nn.Cell):
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
@@ -35,6 +35,7 @@ class NetWithSparseGatherV2(nn.Cell):
def construct(self, indices, label):
return self.gather(self.weight1, indices, self.axis) + self.weight2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -56,6 +57,7 @@ def test_ftrl_net():
[[0.6821311, 0.6821311]],
[[0.6821311, 0.6821311]]]))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training

View File

@@ -165,7 +165,6 @@ def test_allreduce_sparsegatherv2_adam_auto_parallel():
indices = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]).astype(np.int32))
epoch = 3
batch_size = 1
context.set_context(enable_sparse=True)
net = NetWithSparseGatherV2(sparse=True)
output_sparse = net.train_mindspore_impl(indices, epoch, batch_size)
net = NetWithSparseGatherV2(sparse=False)

View File

@@ -97,7 +97,7 @@ def test_resnet_imagenet_8p_mpi():
"""
epoch_size = 2
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_graph_kernel=False, enable_sparse=False)
context.set_context(enable_graph_kernel=False)
context.reset_auto_parallel_context()
context.reset_ps_context()
context.set_auto_parallel_context(device_num=8, parallel_mode=ParallelMode.DATA_PARALLEL,

View File

@@ -33,7 +33,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target
context.set_context(
mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
mode=context.GRAPH_MODE, device_target=device_target
)
context.set_ps_context(enable_ps=True)

View File

@@ -43,7 +43,7 @@ parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspo
args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True)
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)

View File

@@ -20,8 +20,7 @@ from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import FTRL, LazyAdam
from mindspore.ops import operations as P
context.set_context(enable_sparse=True,
mode=context.PYNATIVE_MODE,
context.set_context(mode=context.PYNATIVE_MODE,
device_target="Ascend")
class NetWithSparseGatherV2(nn.Cell):

View File

@@ -23,7 +23,7 @@ from mindspore.ops import composite as C
def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
context.set_context(mode=context.PYNATIVE_MODE)
class _Grad(nn.Cell):

View File

@@ -42,12 +42,7 @@ from mindspore.train import Model
from ....dataset_mock import MindData
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
@@ -115,6 +110,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
def __init__(self):
"""init index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr('bprop_return_sparse', True)
def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)

View File

@@ -26,12 +26,8 @@ import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore import Tensor, COOTensor, context
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
grad_op = C.GradOperation(get_all=True)

View File

@@ -14,7 +14,6 @@
# ============================================================================
""" test ADA_GRAD """
import pytest
import numpy as np
import mindspore.nn as nn
@@ -25,11 +24,7 @@ from mindspore.nn.optim import Adagrad
from mindspore.ops import operations as P
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):

View File

@@ -14,7 +14,6 @@
# ============================================================================
""" test adafactor """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
@@ -24,11 +23,7 @@ from mindspore.nn.optim.adafactor import AdaFactor
from mindspore.ops import operations as P
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):

View File

@@ -23,11 +23,9 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay
from mindspore.ops import operations as P
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
""" Net definition """

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test FTRL """
import pytest
import numpy as np
import mindspore.nn as nn
@@ -23,11 +22,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import FTRL
from mindspore.ops import operations as P
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test PROXIMAL_ADA_GRAD """
import pytest
import numpy as np
import mindspore.nn as nn
@@ -23,11 +22,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import ProximalAdagrad
from mindspore.ops import operations as P
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):

View File

@@ -22,11 +22,8 @@ from mindspore.nn import Cell, TrainOneStepCell, LazyAdam
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(enable_sparse=True)
yield
context.set_context(enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
class Net(Cell):

View File

@@ -29,9 +29,8 @@ grad_all = C.GradOperation(get_all=True)
@pytest.fixture(name="test_context")
def _test_context():
context.set_context(enable_sparse=True)
context.set_context(mode=context.GRAPH_MODE)
yield
context.set_context(enable_sparse=False)
context.reset_auto_parallel_context()

View File

@@ -32,9 +32,8 @@ grad_all = C.GradOperation(get_all=True)
@pytest.fixture(name="test_context")
def _test_context():
context.set_context(enable_sparse=True)
context.set_context(mode=context.GRAPH_MODE)
yield
context.set_context(enable_sparse=False)
context.reset_auto_parallel_context()

View File

@@ -26,9 +26,9 @@ from mindspore.ops import composite as C
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)
context.set_context(mode=context.PYNATIVE_MODE)
yield
context.set_context(mode=context.GRAPH_MODE, enable_sparse=False)
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)