forked from mindspore-Ecosystem/mindspore

Add bprop_return_sparse flag for sparse bprop primitive and remove context option: 'enable_sparse'.

parent da2a22dc22
commit d7762f1c8d
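In short, this commit moves the sparse-bprop switch from a global context option to a per-primitive attribute. A minimal before/after sketch of user code (illustrative only; the commented-out line shows the removed option):

    import mindspore.context as context

    # Before this commit, sparse bprop had to be enabled globally:
    # context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

    # After this commit the option is gone; primitives such as SparseGatherV2
    # and EmbeddingLookup declare 'bprop_return_sparse' themselves, so no
    # context flag is needed.
    context.set_context(mode=context.GRAPH_MODE)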
@@ -59,6 +59,7 @@ using mindspore::abstract::AbstractTensor;
 using mindspore::abstract::AbstractTuple;
 using mindspore::abstract::AbstractTuplePtr;
 using mindspore::abstract::AbstractUndetermined;
+using mindspore::abstract::EnvSetSparseResultMgr;
 using mindspore::abstract::FuncGraphAbstractClosure;

 void HyperMap::Init() {

@@ -645,19 +646,12 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
 }

 FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
-                                    const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
-                                    const std::vector<AnfNodePtr> &weight_args) {
+                                    const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad) {
   FuncGraphPtr k_child = std::make_shared<FuncGraph>();
   k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);

-  AnfNodePtr weights_node = nullptr;
   AnfNodePtr position_node = nullptr;
-  if (weights != nullptr) {
-    weights_node = weights;
-  } else if (!weight_args.empty()) {
-    weights_node = k_child->NewCNodeInOrder(weight_args);
-  }
   if (position != nullptr) {
     position_node = position;
   }

@@ -673,7 +667,7 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weigh
   auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
   auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});

-  GradByParameter(k_child, f_app, bprop, weights_node, position_node, enable_tuple_grad);
+  GradByParameter(k_child, f_app, bprop, weights, position_node, enable_tuple_grad);
   return k_child;
 }

@@ -740,6 +734,38 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
   k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
 }

+namespace {
+// Check if primal func graph has the primitive returned sparse result in its bprop().
+void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
+  bool has_sparse_bprop_prim = false;
+  (void)TopoSort(primal_graph->return_node(), SuccDeeperSimple,
+                 [&has_sparse_bprop_prim](const AnfNodePtr &node) -> IncludeType {
+                   MS_EXCEPTION_IF_NULL(node);
+                   if (has_sparse_bprop_prim) {
+                     return EXCLUDE;
+                   }
+                   auto prim = GetCNodePrimitive(node);
+                   if (prim != nullptr) {
+                     auto do_signature = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
+                     if (do_signature != nullptr) {
+                       prim = dyn_cast<Primitive>(do_signature->function());
+                     }
+                     bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
+                     if (sparse_bprop) {
+                       MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";
+                       has_sparse_bprop_prim = true;
+                       return EXCLUDE;
+                     }
+                   }
+                   return FOLLOW;
+                 });
+  if (has_sparse_bprop_prim) {
+    primal_graph->set_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP, true);
+    EnvSetSparseResultMgr::GetInstance().Set(true);
+  }
+}
+}  // namespace
+
 // Generate the graph.
 FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
   if (args_spec_list.empty()) {

@@ -761,6 +787,10 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
   FuncGraphPtr forward_graph = real_fn->func_graph();
   MS_EXCEPTION_IF_NULL(forward_graph);
   forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
+
+  // Check if primal func graph has the primitive returned sparse result in its bprop().
+  CheckPrimBpropReturnSparse(forward_graph);
+
   FuncGraphPtr grad_fg = nullptr;
   {
     TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
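The flag that CheckPrimBpropReturnSparse() reads via GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE) is declared on the Python side with add_prim_attr, as the SparseGatherV2 and EmbeddingLookup changes later in this diff do. A hedged sketch of declaring and inspecting it on a hypothetical primitive, assuming Primitive.attrs behaves as in recent MindSpore releases:

    from mindspore.ops import Primitive

    # 'MySparseGather' is a hypothetical name used only for illustration.
    prim = Primitive('MySparseGather')
    # Mark the primitive so graph compilation knows its bprop returns a sparse result.
    prim.add_prim_attr('bprop_return_sparse', True)
    print(prim.attrs.get('bprop_return_sparse'))  # expected: True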
@@ -150,8 +150,7 @@ class GradOperation : public MetaFuncGraph {
   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)

   FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
-                       const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad_first,
-                       const std::vector<AnfNodePtr> &weight_args = {});
+                       const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad_first);

   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
@@ -353,6 +353,10 @@ AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePt
   bool result = false;

   if (!CheckCmpValid(cmp)) {
+    auto cmp_type = cmp->BuildType();
+    MS_EXCEPTION_IF_NULL(cmp_type);
+    MS_LOG(ERROR) << "cmp: " << cmp->ToString() << ", cmp_type: " << cmp_type->ToString()
+                  << ", cmp_type_id: " << TypeIdToType(cmp_type->type_id());
     MS_EXCEPTION(TypeError) << "isinstance() arg 2 must be a type or tuple of types.";
   }

@@ -161,9 +161,7 @@ Status GatherInfo::GetAttrs() {
     dynamic_shape_indices_ = true;
   }
 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
-  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
-  bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
-  if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
+  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
     dynamic_shape_indices_ = true;
   }
 #endif
@@ -602,16 +602,19 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
   // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
   AbstractBasePtrList bparams;
   bparams.push_back(SensitivityTransform(primal_func_));
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
+  // Check if primal func graph has the primitive returned sparse result in its bprop().
+  auto real_primal_func = dyn_cast<FuncGraphAbstractClosure>(primal_func_);
+  MS_EXCEPTION_IF_NULL(real_primal_func);
+  FuncGraphPtr primal_func_graph = real_primal_func->func_graph();
+  MS_EXCEPTION_IF_NULL(primal_func_graph);
+  bool has_sparse_bprop_prim = primal_func_graph->has_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP);
   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(bparams),
-                       [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
-                         MS_EXCEPTION_IF_NULL(arg_spec);
-                         if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
+                       [&has_sparse_bprop_prim](const AbstractBasePtr &arg_abs) -> AbstractBasePtr {
+                         MS_EXCEPTION_IF_NULL(arg_abs);
+                         if (has_sparse_bprop_prim && arg_abs->isa<AbstractTensor>()) {
                            return std::make_shared<AbstractUndetermined>();
                          }
-                         return SensitivityTransform(arg_spec);
+                         return SensitivityTransform(arg_abs);
                        });
   AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
   AbstractFunctionPtr bprop =
@@ -66,21 +66,14 @@ class Evaluator : public Base {
   }

   virtual EvalResultPtr EvalUndeterminedArgs(const AbstractBasePtrList &args_abs_list) {
-    auto context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context);
-    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
-    if (!enable_sparse) {
-      return nullptr;
-    }
-
-    auto is_abstract = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](auto &arg) -> bool {
+    auto is_undetermined = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](auto &arg) -> bool {
       if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
         return true;
       }
       return false;
     });
-    if (is_abstract) {
-      MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
+    if (is_undetermined) {
+      MS_LOG(DEBUG) << "Eval " << identifier_ << " return undetermined abstract result";
       return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
     }
     return nullptr;
@@ -876,13 +876,6 @@ AbstractBasePtr MakePyInferRes2Abstract(const py::object &output) {
     // Return monad abstract if it is monad type.
     return ToMonadAbstract(type_obj);
   } else {
-    // When sparse enabled, the undetermined might be raised and eliminated in opt passes
-    auto context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context);
-    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
-    if (enable_sparse) {
-      return std::make_shared<abstract::AbstractUndetermined>();
-    }
     MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
   }
 }
@@ -381,8 +381,10 @@ void AnalysisEngine::ClearEvaluatorCache() {
     MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
     evaluator->evaluator_cache_mgr()->Clear();
   }
-  // Release Exception to avoid hup at exit.
+  // Release exception to avoid hup at exit.
   StaticAnalysisException::Instance().ClearException();
+  // Reset the EnvironGet sparse option.
+  EnvSetSparseResultMgr::GetInstance().Set(false);
 }

 void AnalysisEngine::Clear() {
@@ -78,7 +78,6 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                          .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
                          .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
                          .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
-                         .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
                          .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
                          .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
                          .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
@@ -1909,9 +1909,7 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph &g
     } else if (cnode->isa<CNode>() && (common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
       MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
       FinalizePsCache(
-        "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
-        "the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
-        "context setting in parameter server training mode.");
+        "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at the same time.");
     }
   }
 }
@@ -34,6 +34,7 @@
 #include "ir/value.h"
 #include "ir/tensor.h"
 #include "abstract/dshape.h"
+#include "abstract/utils.h"
 #include "utils/shape_utils.h"

 namespace mindspore {
@@ -117,10 +117,7 @@ AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePt
   }

   MS_LOG(DEBUG) << "key: " << key->ToString() << ", value: " << default_value->ToString();
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
-  if (enable_sparse && default_value->isa<AbstractTensor>()) {
+  if (default_value->isa<AbstractTensor>() && EnvSetSparseResultMgr::GetInstance().Get()) {
     auto tensor_value = default_value->cast<AbstractTensorPtr>();
     MS_EXCEPTION_IF_NULL(tensor_value);
     return std::make_shared<AbstractUndetermined>(tensor_value->element()->Clone(), tensor_value->shape()->Clone());

@@ -149,7 +146,7 @@ AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePt
   // args: Three objects of a subclass of AbstractBase, env, key, value.
   CheckArgsSize(primitive->name(), args_spec_list, kSizeThree);

-  auto key = args_spec_list[1];
+  auto key = args_spec_list[kIndexOne];
   ValuePtr key_value_ptr = key->GetValueTrack();
   MS_EXCEPTION_IF_NULL(key_value_ptr);
   auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();

@@ -160,7 +157,11 @@ AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePt
   auto expected = key_value_track->abstract();
   MS_EXCEPTION_IF_NULL(expected);

-  MS_LOG(DEBUG) << "key: " << key->ToString() << ", value: " << args_spec_list[kIndexTwo]->ToString();
+  auto value = args_spec_list[kIndexTwo];
+  MS_LOG(DEBUG) << "key: " << key->ToString() << ", value: " << value->ToString();
+  if (value->isa<AbstractUndetermined>() && !value->isa<AbstractTensor>()) {
+    EnvSetSparseResultMgr::GetInstance().Set(true);
+  }
   return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
 }

@@ -307,14 +307,7 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
     // Return monad abstract if it is monad type.
     return MakeMonadAbstract(type->cast<MonadTypePtr>());
   } else {
-    // When sparse enabled, the undetermined might be raised and eliminated in opt passes
-    auto context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context);
-    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
-    if (enable_sparse) {
-      return std::make_shared<abstract::AbstractUndetermined>();
-    }
-    MS_LOG(EXCEPTION) << "evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
+    MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
   }
 }
 }  // namespace abstract
@@ -58,6 +58,24 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec
 AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
 MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
 MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
+
+class MS_CORE_API EnvSetSparseResultMgr {
+ public:
+  static EnvSetSparseResultMgr &GetInstance() noexcept {
+    static EnvSetSparseResultMgr instance;
+    return instance;
+  }
+  EnvSetSparseResultMgr(const EnvSetSparseResultMgr &) = delete;
+  EnvSetSparseResultMgr &operator=(const EnvSetSparseResultMgr &) = delete;
+  ~EnvSetSparseResultMgr() = default;
+
+  bool Get() const { return env_set_sparse_result_; }
+  void Set(bool env_set_sparse_result) { env_set_sparse_result_ = env_set_sparse_result; }
+
+ private:
+  EnvSetSparseResultMgr() = default;
+  bool env_set_sparse_result_{false};
+};
 }  // namespace abstract
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_ABSTRACT_UTILS_H_
@@ -82,6 +82,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;

 const char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
 const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
+const char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
 const char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
 const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
 const char FUNC_GRAPH_FLAG_CORE[] = "core";
@@ -163,8 +163,8 @@ std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root) {
   return todo;
 }

-// PushSuccessors push cnode inputs to a vector as successors for topo sort.
-static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
+// To get CNode inputs to a vector as successors for TopoSort().
+static void FetchCNodeSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
   auto &inputs = cnode->inputs();
   vecs->reserve(vecs->size() + inputs.size());

@@ -194,7 +194,7 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
     return vecs;
   } else if (node->func_graph() != nullptr) {
     if (node->isa<CNode>()) {
-      PushSuccessors(node->cast<CNodePtr>(), &vecs);
+      FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
     }
     return vecs;
   }

@@ -217,7 +217,7 @@ std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {
     return vecs;
   } else {
     if (node->isa<CNode>()) {
-      PushSuccessors(node->cast<CNodePtr>(), &vecs);
+      FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
     }
     return vecs;
   }

@@ -227,7 +227,7 @@ std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
   std::vector<AnfNodePtr> vecs;
   auto cnode = dyn_cast<CNode>(node);
   if (cnode != nullptr) {
-    PushSuccessors(cnode, &vecs);
+    FetchCNodeSuccessors(cnode, &vecs);
   }
   return vecs;
 }

@@ -251,7 +251,7 @@ std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &
       }
     }
   }
-  PushSuccessors(cnode, &vecs);
+  FetchCNodeSuccessors(cnode, &vecs);
   }
   return vecs;
 }

@@ -275,7 +275,7 @@ std::vector<AnfNodePtr> SuccWithFilter(const GraphFilterFunc &graph_filter, cons
     return vecs;
   } else {
     if (node->isa<CNode>()) {
-      PushSuccessors(node->cast<CNodePtr>(), &vecs);
+      FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
     }
     return vecs;
   }
@@ -28,13 +28,6 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
 }

 FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
-  if (!enable_sparse) {
-    return nullptr;
-  }
-
   std::vector<AnfNodePtr> parameters;
   ParameterPtr undetermined_param = nullptr;
   auto stub = std::make_shared<FuncGraph>();
@@ -31,6 +31,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
 inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
 inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
 inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
+inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse";

 // method names of python primitive called from c++ source code
 // 1. infer method name of class 'PrimitiveWithInfer'
@@ -92,7 +92,6 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE, kDefaultMempoolBlockSize);
   set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
   set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
-  set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
   set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
   set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
   set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
@@ -77,7 +77,6 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENABLE_PYNATIVE_HOOK,
   MS_CTX_ENABLE_PYNATIVE_INFER,
   MS_CTX_ENABLE_REDUCE_PRECISION,
-  MS_CTX_ENABLE_SPARSE,
   MS_CTX_ENABLE_TASK_SINK,
   MS_CTX_IR_FUSION_FLAG,
   MS_CTX_IS_MULTI_GRAPH_SINK,
@@ -642,9 +642,9 @@ def _check_target_specific_cfgs(device, arg_key):
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                  enable_auto_mixed_precision=bool,
                  enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
-                 max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
-                 env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int,
-                 load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
+                 max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
+                 graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
+                 grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
 def set_context(**kwargs):
     """
     Set context for running environment.

@@ -704,8 +704,6 @@ def set_context(**kwargs):
     |                         +------------------------------+----------------------------+
     |                         |  max_call_depth              |  CPU/GPU/Ascend            |
     |                         +------------------------------+----------------------------+
-    |                         |  enable_sparse               |  CPU/GPU/Ascend            |
-    |                         +------------------------------+----------------------------+
     |                         |  grad_for_scalar             |  CPU/GPU/Ascend            |
     |                         +------------------------------+----------------------------+
     |                         |  enable_compile_cache        |  CPU/GPU/Ascend            |

@@ -827,9 +825,6 @@ def set_context(**kwargs):
         The max_call_depth parameter needs to be set when the nested call is too deep or the number
         of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should be
         set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow.
-        enable_sparse (bool): Whether to enable sparsity feature. Default: False.
-            For details of sparsity and sparse tensor, please check
-            `sparse tensor <https://www.mindspore.cn/tutorials/en/master/beginner/tensor.html#sparse-tensor>`_.
         grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
             When grad_for_scalar is set to True, the function's scalar input can be derived.
             The default value is False. Because the back-end does not support scaling operations currently,

@@ -867,7 +862,6 @@ def set_context(**kwargs):
         >>> set_context(max_device_memory="3.5GB")
         >>> set_context(mempool_block_size="1GB")
         >>> set_context(print_file_path="print.pb")
-        >>> set_context(enable_sparse=True)
         >>> set_context(max_call_depth=80)
         >>> set_context(env_config_path="./env_config.json")
         >>> set_context(auto_tune_mode="GA,RL")
@@ -25,7 +25,6 @@ from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
 from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
 from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
-from mindspore import context
 from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
 from mindspore.ops.primitive import constexpr

@@ -339,12 +338,6 @@ class EmbeddingLookup(Cell):
                 self.cache_enable = True
                 if _is_role_worker():
                     self.vocab_size = self.vocab_cache_size
-                if context.get_context("enable_sparse") != self.sparse:
-                    raise ValueError(f"For '{self.cls_name}', the value of parameter 'sparse' must be same for all "
-                                     f"kernels and equal the value of 'enable_sparse' in context setting in "
-                                     f"parameter server cache mode, but got value of parameter 'sparse': {self.sparse}"
-                                     f" and the 'enable_sparse' in context setting: "
-                                     f"{context.get_context('enable_sparse')}.")

     def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
         """PS embeddingLookup cache enable set."""
@@ -898,10 +898,6 @@ class EmbeddingLookupThor(Cell):
                 self.cache_enable = True
                 if _is_role_worker():
                     self.vocab_size = self.vocab_cache_size
-                if context.get_context("enable_sparse") != self.sparse:
-                    raise ValueError(f"For '{self.cls_name}', the 'sparse' must be equal to the 'enable_sparse' "
-                                     f"in context setting in parameter server cache mode, but got 'sparse': "
-                                     f"{self.sparse}, 'enable_sparse': {context.get_context('enable_sparse')}.")

     def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
         """PS embeddingLookup cache enable set."""
@@ -26,7 +26,6 @@ from ..functional import broadcast_gradient_args
 from .. import functional as F
 from .grad_base import bprop_getters
 from ..primitive import constexpr
-from ... import context
 from ...common import dtype as mstype
 from ...common.tensor import RowTensor
 from .._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown

@@ -116,20 +115,10 @@ def dout_cast_row_tensor(dout, x):
 @bprop_getters.register(P.Cast)
 def get_bprop_cast(self):
     """Generate bprop for Cast"""
     cast = P.Cast()
     get_dtype = P.DType()

     def bprop(x, t, out, dout):
         dx = cast(dout, get_dtype(x))
         return dx, zeros_like(t)

-    def bprop_sparse(x, t, out, dout):
-        dx = dout_cast(dout, x)
-        return dx, zeros_like(t)
-
-    if context.get_context('enable_sparse'):
-        return bprop_sparse
-
     return bprop

@@ -9,6 +9,6 @@ y
 bprop.12:x*
 bprop.12:out*
 bprop.12:dout2
-bprop.12:[CNode]14:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
-S-Prim-MakeTuple:4S-Prim-MakeTuplebH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
+bprop.12:[CNode]14:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPbH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:4S-Prim-MakeTupleh

@@ -9,6 +9,6 @@ z
 bprop.15:x*
 bprop.15:out*
 bprop.15:dout2
-bprop.15:[CNode]17:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
+bprop.15:[CNode]17:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPb&
 S-Prim-MakeTuple:4S-Prim-MakeTuplebH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

@@ -9,6 +9,6 @@ z
 bprop.18:x*
 bprop.18:out*
 bprop.18:dout2
-bprop.18:[CNode]20:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
-S-Prim-MakeTuple:4S-Prim-MakeTuplebH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
+bprop.18:[CNode]20:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPbH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:4S-Prim-MakeTupleh

@@ -9,6 +9,6 @@ z
 bprop.24:x*
 bprop.24:out*
 bprop.24:dout2
-bprop.24:[CNode]26:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:4S-Prim-MakeTupleh
+bprop.24:[CNode]26:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPb&
+S-Prim-MakeTuple:4S-Prim-MakeTuplebH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

@@ -5,5 +5,5 @@ m
 bprop.1:x*
 bprop.1:out*
 bprop.1:dout2
-bprop.1:[CNode]2:1:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
+bprop.1:[CNode]2:1:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPb&
 S-Prim-MakeTuple:2S-Prim-MakeTupleh

@@ -7,6 +7,6 @@ s
 bprop.6:x*
 bprop.6:out*
 bprop.6:dout2
-bprop.6:[CNode]8:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
+bprop.6:[CNode]8:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPbH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
 S-Prim-MakeTuple:4S-Prim-MakeTupleh

@@ -7,6 +7,6 @@ s
 bprop.3:x*
 bprop.3:out*
 bprop.3:dout2
-bprop.3:[CNode]5:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:4S-Prim-MakeTupleh
+bprop.3:[CNode]5:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPb&
+S-Prim-MakeTuple:4S-Prim-MakeTuplebH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

@@ -9,6 +9,6 @@ z
 bprop.27:x*
 bprop.27:out*
 bprop.27:dout2
-bprop.27:[CNode]29:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
-S-Prim-MakeTuple:4S-Prim-MakeTuplebH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
+bprop.27:[CNode]29:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPbH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:4S-Prim-MakeTupleh
Binary file not shown.
Binary file not shown.
@@ -27,9 +27,9 @@ bprop.30:x*
 bprop.30:y*
 bprop.30:out*
 bprop.30:dout2
-bprop.30:[CNode]36:8:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pbv
+bprop.30:[CNode]36:8:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPbv
 S-Prim-Select:5
 S-Prim-Select
 output_names€ŠZoutput€3
-input_names€ŠZ condition€ŠZx€ŠZy€bH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:9S-Prim-MakeTupleh
+input_names€ŠZ condition€ŠZx€ŠZy€b&
+S-Prim-MakeTuple:9S-Prim-MakeTuplebH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

@@ -9,6 +9,6 @@ z
 bprop.21:x*
 bprop.21:out*
 bprop.21:dout2
-bprop.21:[CNode]23:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
-S-Prim-MakeTuple:4S-Prim-MakeTuplebH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
+bprop.21:[CNode]23:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPbH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:4S-Prim-MakeTupleh

@@ -7,6 +7,6 @@ v
 bprop.9:x*
 bprop.9:out*
 bprop.9:dout2
-bprop.9:[CNode]11:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
+bprop.9:[CNode]11:3:@231212aa03e0343893c3476fd4c71a316576977db00252cc9713d543de78d44cPb&
 S-Prim-MakeTuple:4S-Prim-MakeTuplebH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@@ -913,6 +913,7 @@ class SparseGatherV2(PrimitiveWithCheck):
     def __init__(self):
         """Initialize SparseGatherV2"""
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
+        self.add_prim_attr('bprop_return_sparse', True)

     def __check__(self, params, indices, axis):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)

@@ -6291,6 +6292,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
         self.__setattr_flag__ = True
         self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
                                 outputs=['output'])
+        self.add_prim_attr('bprop_return_sparse', True)

     def __check__(self, params, indices, offset):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
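With these attributes in place, sparse training scripts no longer toggle any context option, which is what the test changes below exercise. A hedged end-to-end sketch (network shape and input data are illustrative only):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE)  # no enable_sparse needed

    class SparseGatherNet(nn.Cell):
        """Illustrative net; SparseGatherV2 now carries 'bprop_return_sparse' itself."""
        def __init__(self):
            super(SparseGatherNet, self).__init__()
            self.weight = Parameter(Tensor(np.ones([8, 2]).astype(np.float32)), name="weight")
            self.axis = 0
            self.gather = P.SparseGatherV2()

        def construct(self, indices):
            return self.gather(self.weight, indices, self.axis)

    net = SparseGatherNet()
    out = net(Tensor(np.array([0, 2, 4]).astype(np.int32)))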
@@ -244,8 +244,6 @@ class ParallelMultiHotFactory:
         self.opt = Adam(params=net.get_parameters())
         if self.target == 'CPU':
             self.opt.target = self.target
-        if self.sparse:
-            context.set_context(enable_sparse=True)
         self.model = Model(network=net,
                            loss_fn=self.loss_fn,
                            optimizer=self.opt)
@@ -19,8 +19,7 @@ from mindspore import Tensor, context
 from mindspore.nn import TrainOneStepCell, WithLossCell


-context.set_context(enable_sparse=True,
-                    mode=context.GRAPH_MODE)
+context.set_context(mode=context.GRAPH_MODE)


 class NetWithEmbeddingLookUp(nn.Cell):
@@ -20,10 +20,10 @@ from mindspore.nn import TrainOneStepCell
 from mindspore.nn.optim import FTRL, LazyAdam
 from mindspore.ops import operations as P

-context.set_context(enable_sparse=True,
-                    mode=context.GRAPH_MODE,
+context.set_context(mode=context.GRAPH_MODE,
                     device_target="Ascend")


 class NetWithSparseGatherV2(nn.Cell):
     def __init__(self):
         super(NetWithSparseGatherV2, self).__init__()

@@ -35,6 +35,7 @@ class NetWithSparseGatherV2(nn.Cell):
     def construct(self, indices, label):
         return self.gather(self.weight1, indices, self.axis) + self.weight2

+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training

@@ -56,6 +57,7 @@ def test_ftrl_net():
                              [[0.6821311, 0.6821311]],
                              [[0.6821311, 0.6821311]]]))

+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -165,7 +165,6 @@ def test_allreduce_sparsegatherv2_adam_auto_parallel():
     indices = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]).astype(np.int32))
     epoch = 3
     batch_size = 1
-    context.set_context(enable_sparse=True)
     net = NetWithSparseGatherV2(sparse=True)
     output_sparse = net.train_mindspore_impl(indices, epoch, batch_size)
     net = NetWithSparseGatherV2(sparse=False)
@@ -97,7 +97,7 @@ def test_resnet_imagenet_8p_mpi():
     """
     epoch_size = 2
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(enable_graph_kernel=False, enable_sparse=False)
+    context.set_context(enable_graph_kernel=False)
     context.reset_auto_parallel_context()
     context.reset_ps_context()
     context.set_auto_parallel_context(device_num=8, parallel_mode=ParallelMode.DATA_PARALLEL,
@@ -33,7 +33,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
 args, _ = parser.parse_known_args()
 device_target = args.device_target
 context.set_context(
-    mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
+    mode=context.GRAPH_MODE, device_target=device_target
 )
 context.set_ps_context(enable_ps=True)

@@ -43,7 +43,7 @@ parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspo
 args, _ = parser.parse_known_args()
 device_target = args.device_target
 dataset_path = args.dataset_path
-context.set_context(mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True)
+context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
 context.set_ps_context(enable_ps=True)

@@ -20,8 +20,7 @@ from mindspore.nn import TrainOneStepCell
 from mindspore.nn.optim import FTRL, LazyAdam
 from mindspore.ops import operations as P

-context.set_context(enable_sparse=True,
-                    mode=context.PYNATIVE_MODE,
+context.set_context(mode=context.PYNATIVE_MODE,
                     device_target="Ascend")

 class NetWithSparseGatherV2(nn.Cell):
@@ -23,7 +23,7 @@ from mindspore.ops import composite as C


 def setup_module():
-    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
+    context.set_context(mode=context.PYNATIVE_MODE)


 class _Grad(nn.Cell):
@@ -42,12 +42,7 @@ from mindspore.train import Model
 from ....dataset_mock import MindData


-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
-
+context.set_context(mode=context.GRAPH_MODE)

 reduce_sum = P.ReduceSum()
 unsorted_segment_sum = P.UnsortedSegmentSum()

@@ -115,6 +110,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
     def __init__(self):
         """init index_select"""
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
+        self.add_prim_attr('bprop_return_sparse', True)

     def __infer__(self, params, indices, axis):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
@@ -26,12 +26,8 @@ import mindspore.nn as nn
 from mindspore.ops import composite as C
 from mindspore import Tensor, COOTensor, context

-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
-
+context.set_context(mode=context.GRAPH_MODE)

 grad_op = C.GradOperation(get_all=True)

@@ -14,7 +14,6 @@
 # ============================================================================
 """ test ADA_GRAD """

-import pytest
 import numpy as np

 import mindspore.nn as nn

@@ -25,11 +24,7 @@ from mindspore.nn.optim import Adagrad
 from mindspore.ops import operations as P


-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
+context.set_context(mode=context.GRAPH_MODE)


 class Net(nn.Cell):
@@ -14,7 +14,6 @@
 # ============================================================================
 """ test adafactor """
 import numpy as np
-import pytest

 import mindspore.nn as nn
 from mindspore import Tensor, Parameter, context

@@ -24,11 +23,7 @@ from mindspore.nn.optim.adafactor import AdaFactor
 from mindspore.ops import operations as P


-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
+context.set_context(mode=context.GRAPH_MODE)


 class Net(nn.Cell):
@@ -23,11 +23,9 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam, AdamWeightDecay
 from mindspore.ops import operations as P

-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
+context.set_context(mode=context.GRAPH_MODE)
+

 class Net(nn.Cell):
     """ Net definition """
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """ test FTRL """
-import pytest
 import numpy as np

 import mindspore.nn as nn

@@ -23,11 +22,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import FTRL
 from mindspore.ops import operations as P

-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
-
+context.set_context(mode=context.GRAPH_MODE)


 class Net(nn.Cell):
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """ test PROXIMAL_ADA_GRAD """
-import pytest
 import numpy as np

 import mindspore.nn as nn

@@ -23,11 +22,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import ProximalAdagrad
 from mindspore.ops import operations as P

-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
-
+context.set_context(mode=context.GRAPH_MODE)


 class Net(nn.Cell):
@@ -22,11 +22,8 @@ from mindspore.nn import Cell, TrainOneStepCell, LazyAdam
 from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer

-@pytest.fixture(scope="module", autouse=True)
-def setup_teardown():
-    context.set_context(enable_sparse=True)
-    yield
-    context.set_context(enable_sparse=False)
-
+context.set_context(mode=context.GRAPH_MODE)


 class Net(Cell):
@@ -29,9 +29,8 @@ grad_all = C.GradOperation(get_all=True)

 @pytest.fixture(name="test_context")
 def _test_context():
-    context.set_context(enable_sparse=True)
+    context.set_context(mode=context.GRAPH_MODE)
     yield
-    context.set_context(enable_sparse=False)
     context.reset_auto_parallel_context()

@@ -32,9 +32,8 @@ grad_all = C.GradOperation(get_all=True)

 @pytest.fixture(name="test_context")
 def _test_context():
-    context.set_context(enable_sparse=True)
+    context.set_context(mode=context.GRAPH_MODE)
     yield
-    context.set_context(enable_sparse=False)
     context.reset_auto_parallel_context()

@@ -26,9 +26,9 @@ from mindspore.ops import composite as C

 @pytest.fixture(scope="module", autouse=True)
 def setup_teardown():
-    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)
+    context.set_context(mode=context.PYNATIVE_MODE)
     yield
-    context.set_context(mode=context.GRAPH_MODE, enable_sparse=False)
+    context.set_context(mode=context.GRAPH_MODE)


 grad_all = C.GradOperation(get_all=True)