!27461 Add shard operation
Merge pull request !27461 from YuJianfeng/shard
commit d1d516e668
@@ -1115,5 +1115,76 @@ REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));

namespace {
FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
  FuncGraphPtr shard_child = std::make_shared<FuncGraph>();
  shard_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);

  std::vector<AnfNodePtr> inputs;
  inputs.reserve(origin_graph_params.size() + 1);
  (void)inputs.emplace_back(shard);
  for (size_t i = 0; i < origin_graph_params.size(); ++i) {
    (void)inputs.emplace_back(shard_child->add_parameter());
  }
  auto shard_app = shard_child->NewCNodeInOrder(std::move(inputs));

  shard_child->set_output(shard_app);
  return shard_child;
}
}  // namespace

FuncGraphPtr Shard::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  constexpr size_t shard_input_size = 5;
  if (args_spec_list.size() != shard_input_size) {
    MS_LOG(EXCEPTION) << "'Shard' requires " << shard_input_size
                      << " inputs: a Cell or function, in_axes, out_axes, device and level.";
  }

  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'Shard' arg0 must be a 'Function' or 'Cell', but got " << args_spec_list[0]->ToString()
                      << ".";
  }

  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);
  FuncGraphPtr origin_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(origin_graph);
  origin_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  FuncGraphPtr shard_fg = nullptr;
  {
    TraceGuard g(std::make_shared<TraceShard>(origin_graph->debug_info()));
    shard_fg = std::make_shared<FuncGraph>();
  }
  // Create the debug info
  auto parameter_size = origin_graph->parameters().size();
  std::ostringstream ss;
  ss << "shard{" << parameter_size << "}";
  shard_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  shard_fg->debug_info()->set_name(ss.str());
  // Make the Shard node.
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(args_spec_list.size() + 1);
  (void)inputs.emplace_back(NewValueNode(prim::kPrimShard));
  for (size_t i = 0; i < args_spec_list.size(); ++i) {
    (void)inputs.emplace_back(shard_fg->add_parameter());
  }
  auto shard = shard_fg->NewCNodeInOrder(std::move(inputs));

  FuncGraphPtr shard_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceShard>(shard_fg->debug_info()));
    shard_child = GetShard(shard, origin_graph->parameters());
  }
  shard_fg->set_output(NewValueNode(shard_child));
  return shard_fg;
}

REGISTER_PYBIND_DEFINE(Shard_, ([](const py::module *m) {
                         (void)py::class_<Shard, MetaFuncGraph, std::shared_ptr<Shard>>(*m, "Shard_")
                           .def(py::init<std::string &>(), py::arg("fn"));
                       }));
}  // namespace prim
}  // namespace mindspore

@@ -213,6 +213,23 @@ class TupleGetItemTensor : public MetaFuncGraph {
  }
};
using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;

class Shard : public MetaFuncGraph {
 public:
  explicit Shard(const string &name) : MetaFuncGraph(name) {
    signatures_ =
      // def shard(func:read, in_axes:read, out_axes:read, device:read, level:read):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"in_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"out_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"device", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"level", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
  }
  ~Shard() override = default;
  MS_DECLARE_PARENT(Shard, MetaFuncGraph)

  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};
}  // namespace prim
}  // namespace mindspore

@@ -714,6 +714,26 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
  return AbstractFunction::MakeAbstractFunction(jv);
}

AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: func, in_axes, out_axes, device, level.
  constexpr size_t shard_input_size = 5;
  CheckArgsSize(primitive->name(), args_spec_list, shard_input_size);
  MS_LOG(DEBUG) << "Evaluate Shard: " << args_spec_list[0]->ToString();

  AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
  MS_EXCEPTION_IF_NULL(x);

  AbstractFuncAtomPtrList shard_v;
  auto build_shard_v = [&shard_v](const AbstractFuncAtomPtr &func) {
    auto shard_closure = std::make_shared<ShardTransformedAbstractClosure>(func);
    shard_v.push_back(shard_closure);
  };
  x->Visit(build_shard_v);

  return AbstractFunction::MakeAbstractFunction(shard_v);
}

AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.

@@ -779,6 +799,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferI
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                   InferImplBroadcastGradientArgs, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSlice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);

@@ -55,6 +55,8 @@ AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -523,6 +523,28 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
  return res;
}

EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                  const AnfNodeConfigPtr &) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->ObtainEvalResult()->abstract();
                       });
  MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
  auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
  if (eval_result != nullptr) {
    return eval_result;
  }

  // Call the original evaluator, get the result: y = f(x)
  EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
  MS_EXCEPTION_IF_NULL(result);
  auto res = std::make_shared<EvalResult>(result->abstract(), std::make_shared<AttrValueMap>());
  evaluator_cache_mgr_->SetValue(args_spec_list, res);
  return res;
}

EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list,
                                     const AnfNodeConfigPtr &out_conf) {
  if (args_spec_list.size() != args_spec_list_.size()) {

@@ -357,6 +357,41 @@ class JEvaluator : public Evaluator {
  AbstractFunctionPtr orig_func_;
};

class ShardEvaluator : public Evaluator {
 public:
  ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
      : Evaluator("ShardEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
  ~ShardEvaluator() override = default;
  MS_DECLARE_PARENT(ShardEvaluator, Evaluator);

  AnfNodePtr bound_node() const override {
    if (evaluator_ != nullptr) {
      return evaluator_->bound_node();
    }
    return bound_node_.lock();
  }

  void set_bound_node(const AnfNodePtr &node) override {
    if (evaluator_ != nullptr) {
      evaluator_->set_bound_node(node);
    }
    bound_node_ = AnfNodeWeakPtr(node);
  }

  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
    MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
  }

  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                    const AnfNodeConfigPtr &out_conf) override;

  std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

 private:
  EvaluatorPtr evaluator_;
  AbstractFunctionPtr orig_func_;
};

void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args);
}  // namespace abstract
}  // namespace mindspore

@@ -458,6 +458,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformed
  return jevaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  AbstractFunctionPtr func_orig = func->fn();
  EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  auto shard_evaluator = std::make_shared<ShardEvaluator>(evaluator_orig, func_orig);
  return shard_evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  std::shared_ptr<VirtualEvaluator> virtual_evaluator =

@@ -495,6 +503,8 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
  } else if (func->isa<JTransformedAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
  } else if (func->isa<ShardTransformedAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<ShardTransformedAbstractClosure>>());
  } else if (func->isa<VirtualAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
  } else if (func->isa<PartialAbstractClosure>()) {

@@ -289,6 +289,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &fn);

  FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
  const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }

@@ -254,6 +254,20 @@ std::size_t JTransformedAbstractClosure::hash() const {
  return hash_value;
}

bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<ShardTransformedAbstractClosure>()) {
    return false;
  }
  auto other_transformed = static_cast<const ShardTransformedAbstractClosure *>(&other);
  return fn_ == other_transformed->fn_;
}

std::size_t ShardTransformedAbstractClosure::hash() const {
  MS_EXCEPTION_IF_NULL(fn_);
  auto hash_value = hash_combine(tid(), fn_->hash());
  return hash_value;
}

bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<VirtualAbstractClosure>()) {
    return false;

@@ -322,6 +322,36 @@ class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
  AbstractFuncAtomPtr fn_;
};

/// \brief ShardTransformedAbstractClosure defines interface for abstract of Function
/// transformed through the application of Shard.
class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {
 public:
  /// \brief Constructor of ShardTransformedAbstractClosure
  ///
  /// \param[in] fn The AbstractFuncAtom transformed through the application of Shard.
  explicit ShardTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}

  /// \brief Destructor of ShardTransformedAbstractClosure
  ~ShardTransformedAbstractClosure() override = default;
  MS_DECLARE_PARENT(ShardTransformedAbstractClosure, AbstractFuncAtom)

  /// \brief Get the AbstractFuncAtom ShardTransformedAbstractClosure corresponding to.
  ///
  /// \return The AbstractFuncAtom ShardTransformedAbstractClosure corresponding to.
  AbstractFuncAtomPtr fn() { return fn_; }

  AbstractFunctionPtr Copy() const override { return std::make_shared<ShardTransformedAbstractClosure>(fn_); }

  bool operator==(const AbstractFunction &other) const override;

  std::size_t hash() const override;

  std::string ToString() const override { return "Shard(" + fn_->ToString() + ")"; }

 private:
  AbstractFuncAtomPtr fn_;
};

/// \brief VirtualAbstractClosure defines interface for function with an explicitly
/// fixed type signature.
class MS_CORE_API VirtualAbstractClosure final : public AbstractFuncAtom {

@@ -686,6 +686,7 @@ inline const PrimitivePtr kPrimPyInterpret = std::make_shared<Primitive>("PyInte
// Other primitive not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J", kSideEffectPropagate);
inline const PrimitivePtr kPrimShard = std::make_shared<Primitive>("Shard", kSideEffectPropagate);

// Used to build graph which have keyword arguments
inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");

@@ -402,6 +402,14 @@ class TraceMixedPrecision : public TraceInfo {
  ~TraceMixedPrecision() override = default;
  TraceInfoPtr clone() override { return std::make_shared<TraceMixedPrecision>(*this); }
};

class TraceShard : public TraceInfo {
 public:
  explicit TraceShard(const DebugInfoPtr &info) : TraceInfo(info) {}
  ~TraceShard() override = default;
  std::string name() const override { return "shard_ops"; }
  TraceInfoPtr clone() override { return std::make_shared<TraceShard>(*this); }
};
}  // namespace mindspore

#endif  // MINDSPORE_CORE_UTILS_TRACE_INFO_H_

@@ -21,7 +21,7 @@ Pre-defined combination of operators.


from .base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    core, env_get, tail, zip_operation
    core, env_get, tail, zip_operation, Shard
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like

@@ -60,4 +60,5 @@ __all__ = [
    'repeat_elements',
    'sequence_mask',
    'matmul',
    '_Grad']
    '_Grad',
    'Shard']

@@ -20,7 +20,7 @@ from functools import partial
from types import FunctionType

from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func

@@ -735,6 +735,48 @@ class Map(Map_):
        return tuple(map(func, *args_list))


class Shard(Shard_):
    """Shard operation"""
    def __init__(self):
        """Initialize Shard."""
        Shard_.__init__(self, 'Shard')
        self.shard_fn = None
        self.fn = None
        self.in_axes = None
        self.out_axes = None
        self.device = None
        self.level = None

    def __call__(self, fn, in_axes, out_axes, device, level=0):
        if not isinstance(in_axes, tuple):
            raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}")
        if not isinstance(out_axes, tuple):
            raise TypeError(f"For 'Shard', the 'out_axes' should be a tuple, "
                            f"but got {type(out_axes).__name__}")
        if not isinstance(device, str):
            raise TypeError(f"For 'Shard', the 'device' should be a string, "
                            f"but got {type(device).__name__}")
        if not isinstance(level, int):
            raise TypeError(f"For 'Shard', the 'level' should be an integer, "
                            f"but got {type(level).__name__}")
        if self.shard_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes and \
                self.device == device and self.level == level:
            return self.shard_fn
        shard_ = Shard()

        @ms_function
        def after_shard(*args):
            return shard_(fn, in_axes, out_axes, device, level)(*args)

        self.shard_fn = after_shard
        self.fn = fn
        self.in_axes = in_axes
        self.out_axes = out_axes
        self.device = device
        self.level = level
        return self.shard_fn


class _ListAppend(ListAppend_):
    """
    A metafuncgraph class that append one element to list.

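A minimal usage sketch of the composite Shard wrapper defined above, assuming a MindSpore build that includes this change; the callable and axis values are illustrative placeholders, not part of the commit. It only exercises the argument checks in Shard.__call__, which reject a non-tuple in_axes before any graph compilation happens.

# Hypothetical sketch (illustrative, not from this commit): exercise the
# argument checks of the new composite Shard wrapper.
from mindspore.ops.composite import Shard

def identity(x):
    return x

shard_op = Shard()
try:
    # in_axes is passed as a list, so the isinstance(in_axes, tuple) check raises.
    shard_op(identity, in_axes=[0], out_axes=(0,), device="Ascend", level=0)
except TypeError as err:
    print(err)
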
@@ -28,7 +28,7 @@ from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .composite import _Grad
from .composite import _Grad, Shard
from .._c_expression import security

typeof = Primitive('typeof')

@@ -338,6 +338,10 @@ def vjp(fn, inputs, v):
        return wrap_container(*inputs, v)
    return wrap_container(inputs, v)

shard_fn = Shard()
def shard(fn, in_axes, out_axes, device, level=0):
    return shard_fn(fn, in_axes, out_axes, device, level)


@constexpr
def _raise_type_error():

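A possible call pattern for the functional shard helper defined above, again assuming a build with this change; the Cell, axis tuples, and device string are placeholder assumptions, and the wrapped callable is not invoked here since compilation would require a configured parallel context.

# Hypothetical call pattern (assumed, not part of this commit): wrap a Cell with
# the new functional shard helper; axis tuples and device string are placeholders.
import mindspore.nn as nn
import mindspore.ops.functional as F

class DenseNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(8, 8)

    def construct(self, x):
        return self.dense(x)

# F.shard returns an ms_function-wrapped callable; calling it would compile the
# graph with the Shard primitive handled by the C++ changes above.
sharded = F.shard(DenseNet(), in_axes=((1, 0),), out_axes=((1, 0),), device="Ascend", level=0)
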
@@ -322,4 +322,52 @@ TEST_F(TestComposite, test_ZipOperation) {
  size_t expect = 3;
  ASSERT_EQ(real, expect);
}

/// Feature: Shard operation.
/// Description: Test the func_graph generation of Shard op and the inference of the Shard caller.
/// Expectation: Generate and infer successfully.
TEST_F(TestComposite, test_shard) {
  // Make origin func_graph which includes a relu node.
  FuncGraphPtr origin_func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimRelu));
  inputs.push_back(origin_func_graph->add_parameter());
  CNodePtr relu = origin_func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(relu);
  CNodePtr origin_return = origin_func_graph->NewCNode(inputs);
  origin_func_graph->set_return(origin_return);

  // Make the func_graph which includes a Shard meta_func_graph.
  FuncGraphPtr shard_func_graph = std::make_shared<FuncGraph>();
  MetaFuncGraphPtr shard_op = std::make_shared<prim::Shard>("shard_op");
  inputs.clear();
  inputs.push_back(NewValueNode(shard_op));
  inputs.push_back(NewValueNode(origin_func_graph));
  for (size_t i = 0; i < 4; ++i) {
    inputs.push_back(NewValueNode(MakeValue(0)));
  }
  CNodePtr shard = shard_func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(shard);
  inputs.push_back(shard_func_graph->add_parameter());
  CNodePtr shard_user = shard_func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(shard_user);
  CNodePtr shard_return = shard_func_graph->NewCNode(inputs);
  shard_func_graph->set_return(shard_return);

  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList args_spec_list = {tensor};

  auto ret = engine_->Run(shard_func_graph, args_spec_list).inferred->abstract();
  ASSERT_NE(ret, nullptr);
  ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
  auto build_shape = ret->BuildShape();
  EXPECT_TRUE(build_shape->isa<abstract::Shape>());
  auto shape = build_shape->cast<abstract::ShapePtr>();
  ASSERT_EQ(shape->shape(), std::vector<int64_t>({2, 3, 4}));
}
}  // namespace mindspore