!27461 Add shard operation

Merge pull request !27461 from YuJianfeng/shard
i-robot 2021-12-16 02:36:13 +00:00 committed by Gitee
commit d1d516e668
16 changed files with 331 additions and 4 deletions

View File

@ -1115,5 +1115,76 @@ REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
*m, "TupleGetItemTensor_")
.def(py::init<std::string &>());
}));
namespace {
FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
FuncGraphPtr shard_child = std::make_shared<FuncGraph>();
shard_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
std::vector<AnfNodePtr> inputs;
inputs.reserve(origin_graph_params.size() + 1);
(void)inputs.emplace_back(shard);
for (size_t i = 0; i < origin_graph_params.size(); ++i) {
(void)inputs.emplace_back(shard_child->add_parameter());
}
auto shard_app = shard_child->NewCNodeInOrder(std::move(inputs));
shard_child->set_output(shard_app);
return shard_child;
}
} // namespace
FuncGraphPtr Shard::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
constexpr size_t shard_input_size = 5;
if (args_spec_list.size() != shard_input_size) {
MS_LOG(EXCEPTION) << "'Shard' requires " << shard_input_size
<< " inputs. Includes a Cell or function, in_axes, out_axes, device and level.";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "'Shard' arg0 must be a 'Function' or 'Cell', but got " << args_spec_list[0]->ToString()
<< ".";
}
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr origin_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(origin_graph);
origin_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
FuncGraphPtr shard_fg = nullptr;
{
TraceGuard g(std::make_shared<TraceShard>(origin_graph->debug_info()));
shard_fg = std::make_shared<FuncGraph>();
}
// Set the name in the debug info of the generated graph.
auto parameter_size = origin_graph->parameters().size();
std::ostringstream ss;
ss << "shard{" << parameter_size << "}";
shard_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
shard_fg->debug_info()->set_name(ss.str());
// Make the Shard node.
std::vector<AnfNodePtr> inputs;
inputs.reserve(args_spec_list.size() + 1);
(void)inputs.emplace_back(NewValueNode(prim::kPrimShard));
for (size_t i = 0; i < args_spec_list.size(); ++i) {
(void)inputs.emplace_back(shard_fg->add_parameter());
}
auto shard = shard_fg->NewCNodeInOrder(std::move(inputs));
FuncGraphPtr shard_child = nullptr;
{
TraceGuard guard(std::make_shared<TraceShard>(shard_fg->debug_info()));
shard_child = GetShard(shard, origin_graph->parameters());
}
shard_fg->set_output(NewValueNode(shard_child));
return shard_fg;
}
REGISTER_PYBIND_DEFINE(Shard_, ([](const py::module *m) {
(void)py::class_<Shard, MetaFuncGraph, std::shared_ptr<Shard>>(*m, "Shard_")
.def(py::init<std::string &>(), py::arg("fn"));
}));
} // namespace prim
} // namespace mindspore
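At graph level, GenerateFuncGraph builds an outer graph whose five parameters mirror the shard call's inputs, applies kPrimShard to them, and outputs the child graph produced by GetShard; the child graph in turn applies that Shard node to the original function's parameters. A minimal Python sketch of this two-graph structure, where generate_shard_graph and shard_primitive are illustrative names rather than anything defined by this change:

# Rough analogue of Shard::GenerateFuncGraph / GetShard above; `shard_primitive`
# stands in for the application of kPrimShard and is an assumption for illustration.
def generate_shard_graph(shard_primitive):
    def outer(fn, in_axes, out_axes, device, level):   # shard_fg: one parameter per shard input
        shard_app = shard_primitive(fn, in_axes, out_axes, device, level)
        def child(*origin_args):                        # shard_child: one parameter per original input
            return shard_app(*origin_args)
        return child                                    # shard_fg outputs the child graph
    return outer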

View File

@ -213,6 +213,23 @@ class TupleGetItemTensor : public MetaFuncGraph {
}
};
using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
class Shard : public MetaFuncGraph {
public:
explicit Shard(const string &name) : MetaFuncGraph(name) {
signatures_ =
// def shard(func:read, in_axes:read, out_axes:read, device:read, level:read):
std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"in_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"out_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"device", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"level", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
}
~Shard() override = default;
MS_DECLARE_PARENT(Shard, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};
} // namespace prim
} // namespace mindspore

View File

@ -714,6 +714,26 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
return AbstractFunction::MakeAbstractFunction(jv);
}
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: func, in_axes, out_axes, device, level.
constexpr size_t shard_input_size = 5;
CheckArgsSize(primitive->name(), args_spec_list, shard_input_size);
MS_LOG(DEBUG) << "Evaluate Shard: " << args_spec_list[0]->ToString();
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
MS_EXCEPTION_IF_NULL(x);
AbstractFuncAtomPtrList shard_v;
auto build_shard_v = [&shard_v](const AbstractFuncAtomPtr &func) {
auto shard_closure = std::make_shared<ShardTransformedAbstractClosure>(func);
shard_v.push_back(shard_closure);
};
x->Visit(build_shard_v);
return AbstractFunction::MakeAbstractFunction(shard_v);
}
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
@ -779,6 +799,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferI
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
InferImplBroadcastGradientArgs, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSlice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);

View File

@ -55,6 +55,8 @@ AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -523,6 +523,28 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
return res;
}
EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result != nullptr) {
return eval_result;
}
// Delegate to the original evaluator to infer the result of the wrapped function.
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
MS_EXCEPTION_IF_NULL(result);
auto res = std::make_shared<EvalResult>(result->abstract(), std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, res);
return res;
}
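ShardEvaluator::Run looks up the evaluator cache keyed by the argument abstracts and, only on a miss, delegates to the evaluator of the wrapped function, storing the result with a fresh attribute map. A small Python sketch of this delegate-and-cache pattern, with hypothetical names (CachingDelegateEvaluator, inner_eval) that are not MindSpore APIs:

# Illustrative sketch of the caching delegation used by ShardEvaluator::Run.
class CachingDelegateEvaluator:
    def __init__(self, inner_eval):
        self.inner_eval = inner_eval   # evaluator of the wrapped function
        self.cache = {}                # keyed by the tuple of argument abstracts

    def run(self, args):
        key = tuple(args)
        if key in self.cache:          # reuse a previously inferred result
            return self.cache[key]
        result = self.inner_eval(*args)
        self.cache[key] = result
        return result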
EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) {
if (args_spec_list.size() != args_spec_list_.size()) {

View File

@ -357,6 +357,41 @@ class JEvaluator : public Evaluator {
AbstractFunctionPtr orig_func_;
};
class ShardEvaluator : public Evaluator {
public:
ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
: Evaluator("ShardEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
~ShardEvaluator() override = default;
MS_DECLARE_PARENT(ShardEvaluator, Evaluator);
AnfNodePtr bound_node() const override {
if (evaluator_ != nullptr) {
return evaluator_->bound_node();
}
return bound_node_.lock();
}
void set_bound_node(const AnfNodePtr &node) override {
if (evaluator_ != nullptr) {
evaluator_->set_bound_node(node);
}
bound_node_ = AnfNodeWeakPtr(node);
}
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private:
EvaluatorPtr evaluator_;
AbstractFunctionPtr orig_func_;
};
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args);
} // namespace abstract
} // namespace mindspore

View File

@ -458,6 +458,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformed
return jevaluator;
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr func_orig = func->fn();
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
auto shard_evaluator = std::make_shared<ShardEvaluator>(evaluator_orig, func_orig);
return shard_evaluator;
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
std::shared_ptr<VirtualEvaluator> virtual_evaluator =
@ -495,6 +503,8 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
} else if (func->isa<JTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
} else if (func->isa<ShardTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<ShardTransformedAbstractClosure>>());
} else if (func->isa<VirtualAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
} else if (func->isa<PartialAbstractClosure>()) {

View File

@ -289,6 +289,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &fn);
FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }

View File

@ -254,6 +254,20 @@ std::size_t JTransformedAbstractClosure::hash() const {
return hash_value;
}
bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<ShardTransformedAbstractClosure>()) {
return false;
}
auto other_transformed = static_cast<const ShardTransformedAbstractClosure *>(&other);
return fn_ == other_transformed->fn_;
}
std::size_t ShardTransformedAbstractClosure::hash() const {
MS_EXCEPTION_IF_NULL(fn_);
auto hash_value = hash_combine(tid(), fn_->hash());
return hash_value;
}
bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<VirtualAbstractClosure>()) {
return false;

View File

@ -322,6 +322,36 @@ class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
AbstractFuncAtomPtr fn_;
};
/// \brief ShardTransformedAbstractClosure defines the interface for the abstract of a function
/// transformed through the application of Shard.
class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {
public:
/// \brief Constructor of ShardTransformedAbstractClosure
///
/// \param[in] fn The AbstractFuncAtom transformed through the application of Shard.
explicit ShardTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
/// \brief Destructor of ShardTransformedAbstractClosure
~ShardTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(ShardTransformedAbstractClosure, AbstractFuncAtom)
/// \brief Get the AbstractFuncAtom that this ShardTransformedAbstractClosure corresponds to.
///
/// \return The AbstractFuncAtom that this ShardTransformedAbstractClosure corresponds to.
AbstractFuncAtomPtr fn() { return fn_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<ShardTransformedAbstractClosure>(fn_); }
bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override;
std::string ToString() const override { return "Shard(" + fn_->ToString() + ")"; }
private:
AbstractFuncAtomPtr fn_;
};
/// \brief VirtualAbstractClosure defines the interface for a function with an explicitly
/// fixed type signature.
class MS_CORE_API VirtualAbstractClosure final : public AbstractFuncAtom {

View File

@ -686,6 +686,7 @@ inline const PrimitivePtr kPrimPyInterpret = std::make_shared<Primitive>("PyInte
// Other primitive not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J", kSideEffectPropagate);
inline const PrimitivePtr kPrimShard = std::make_shared<Primitive>("Shard", kSideEffectPropagate);
// Used to build graph which have keyword arguments
inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");

View File

@ -402,6 +402,14 @@ class TraceMixedPrecision : public TraceInfo {
~TraceMixedPrecision() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceMixedPrecision>(*this); }
};
class TraceShard : public TraceInfo {
public:
explicit TraceShard(const DebugInfoPtr &info) : TraceInfo(info) {}
~TraceShard() override = default;
std::string name() const override { return "shard_ops"; }
TraceInfoPtr clone() override { return std::make_shared<TraceShard>(*this); }
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_

View File

@ -21,7 +21,7 @@ Pre-defined combination of operators.
from .base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
core, env_get, tail, zip_operation
core, env_get, tail, zip_operation, Shard
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
@ -60,4 +60,5 @@ __all__ = [
'repeat_elements',
'sequence_mask',
'matmul',
'_Grad']
'_Grad',
'Shard']

View File

@ -20,7 +20,7 @@ from functools import partial
from types import FunctionType
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
@ -735,6 +735,48 @@ class Map(Map_):
return tuple(map(func, *args_list))
class Shard(Shard_):
"""Shard operation"""
def __init__(self):
"""Initialize Shard."""
Shard_.__init__(self, 'Shard')
self.shard_fn = None
self.fn = None
self.in_axes = None
self.out_axes = None
self.device = None
self.level = None
def __call__(self, fn, in_axes, out_axes, device, level=0):
if not isinstance(in_axes, tuple):
raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}")
if not isinstance(out_axes, tuple):
raise TypeError(f"For 'Shard', the 'out_axes' should be a tuple, "
f"but got {type(out_axes).__name__}")
if not isinstance(device, str):
raise TypeError(f"For 'Shard', the 'device' should be a string, "
f"but got {type(device).__name__}")
if not isinstance(level, int):
raise TypeError(f"For 'Shard', the 'level' should be an integer, "
f"but got {type(level).__name__}")
if self.shard_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes and \
self.device == device and self.level == level:
return self.shard_fn
shard_ = Shard()
@ms_function
def after_shard(*args):
return shard_(fn, in_axes, out_axes, device, level)(*args)
self.shard_fn = after_shard
self.fn = fn
self.in_axes = in_axes
self.out_axes = out_axes
self.device = device
self.level = level
return self.shard_fn
class _ListAppend(ListAppend_):
"""
A metafuncgraph class that appends one element to a list.

View File

@ -28,7 +28,7 @@ from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .composite import _Grad
from .composite import _Grad, Shard
from .._c_expression import security
typeof = Primitive('typeof')
@ -338,6 +338,10 @@ def vjp(fn, inputs, v):
return wrap_container(*inputs, v)
return wrap_container(inputs, v)
shard_fn = Shard()
def shard(fn, in_axes, out_axes, device, level=0):
return shard_fn(fn, in_axes, out_axes, device, level)
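For reference, a call to this wrapper might look like the sketch below; the network, axis tuples and device string are assumptions for illustration and are not taken from this patch.

# Hypothetical usage of the new functional shard wrapper.
import mindspore.nn as nn
from mindspore.ops import functional as F

net = nn.ReLU()
# in_axes/out_axes must be tuples, device a string and level an int,
# per the checks in composite's Shard.__call__; the values here are made up.
sharded_net = F.shard(net, in_axes=((1, 1),), out_axes=((1, 1),), device="Ascend", level=0)
# sharded_net is the ms_function-decorated callable returned by Shard.__call__;
# invoking it, e.g. sharded_net(x), runs the Shard-transformed graph on the inputs.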
@constexpr
def _raise_type_error():

View File

@ -322,4 +322,52 @@ TEST_F(TestComposite, test_ZipOperation) {
size_t expect = 3;
ASSERT_EQ(real, expect);
}
/// Feature: Shard operation.
/// Description: Test the func_graph generation of the Shard op and the inference of the Shard caller.
/// Expectation: Generate and infer successfully.
TEST_F(TestComposite, test_shard) {
// Make origin func_graph which includes a relu node.
FuncGraphPtr origin_func_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimRelu));
inputs.push_back(origin_func_graph->add_parameter());
CNodePtr relu = origin_func_graph->NewCNode(inputs);
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(relu);
CNodePtr origin_return = origin_func_graph->NewCNode(inputs);
origin_func_graph->set_return(origin_return);
// Make the func_graph which includes a Shard meta_func_graph.
FuncGraphPtr shard_func_graph = std::make_shared<FuncGraph>();
MetaFuncGraphPtr shard_op = std::make_shared<prim::Shard>("shard_op");
inputs.clear();
inputs.push_back(NewValueNode(shard_op));
inputs.push_back(NewValueNode(origin_func_graph));
for (size_t i = 0; i < 4; ++i) {
inputs.push_back(NewValueNode(MakeValue(0)));
}
CNodePtr shard = shard_func_graph->NewCNode(inputs);
inputs.clear();
inputs.push_back(shard);
inputs.push_back(shard_func_graph->add_parameter());
CNodePtr shard_user = shard_func_graph->NewCNode(inputs);
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(shard_user);
CNodePtr shard_return = shard_func_graph->NewCNode(inputs);
shard_func_graph->set_return(shard_return);
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
AbstractBasePtrList args_spec_list = {tensor};
auto ret = engine_->Run(shard_func_graph, args_spec_list).inferred->abstract();
ASSERT_NE(ret, nullptr);
ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
auto build_shape = ret->BuildShape();
EXPECT_TRUE(build_shape->isa<abstract::Shape>());
auto shape = build_shape->cast<abstract::ShapePtr>();
ASSERT_EQ(shape->shape(), std::vector<int64_t>({2, 3, 4}));
}
} // namespace mindspore