Optimize MixedPrecisionCast function

Kang 2020-06-02 20:32:25 +08:00
parent 394cd073ae
commit 041f628cce
10 changed files with 83 additions and 41 deletions

View File

@@ -242,6 +242,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
// Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

View File

@@ -251,6 +251,7 @@ extern const PrimitivePtr kPrimIs_;
extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimMixedPrecisionCast;
// Comm ops
extern const PrimitivePtr kPrimMirror;

View File

@@ -67,7 +67,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
} else {
return param;
}
auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base");
auto cast_helper = prim::kPrimMixedPrecisionCast;
auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
return cast;
}
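
With this change, the front end emits the mixed_precision_cast primitive directly instead of resolving the Python composite _mp_cast_helper; the new MixedPrecisionCastEvaluator (next file) expands the node during abstract evaluation. A minimal usage sketch of the equivalent functional call, assuming a float16 mixed-precision context:

from mindspore.ops import functional as F
from mindspore.common import dtype as mstype

# Cast every float tensor in x (recursing through tuples) to float16;
# non-float values pass through unchanged.
x = F.mixed_precision_cast(mstype.float16, x)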

View File

@@ -147,9 +147,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
if (!prim_->isa<prim::DoSignaturePrimitive>()) {
MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
}
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
@@ -221,9 +218,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
}
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
auto out_node = out_conf->node()->cast<CNodePtr>();
@@ -267,6 +261,63 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
return engine->ForwardConfig(out_conf, fn_conf);
}
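// Recursively insert casts for a value described by node_type: float tensors
// are cast to target_type; tuples are unpacked, cast element-wise, and repacked.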
AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type,
FuncGraphPtr func_graph) {
AnfNodePtr target_node = source_node;
if (node_type->isa<AbstractTensor>()) {
auto x = node_type->cast<AbstractTensorPtr>();
if (x->element()->BuildType()->isa<Float>()) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
MS_EXCEPTION_IF_NULL(cast);
target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type});
}
} else if (node_type->isa<AbstractTuple>()) {
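// Unpack each tuple element with TupleGetItem, cast it recursively, and
// rebuild the tuple with MakeTuple.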
auto x = node_type->cast<AbstractTuplePtr>();
auto &items = x->elements();
std::size_t size = items.size();
std::vector<AnfNodePtr> nodes;
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (int i = 0; i < SizeToInt(size); i++) {
AnfNodePtr tuple_node =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(i)});
AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, items[i], target_type, func_graph);
nodes.emplace_back(node);
}
target_node = func_graph->NewCNode(nodes);
}
return target_node;
}
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
auto out_node = out_conf->node()->cast<CNodePtr>();
const auto &out_node_inputs = out_node->inputs();
if (out_node_inputs.empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
MS_LOG(EXCEPTION) << "MixedPrecisionCast"
<< " args size should be equal to the inputs size minus 1, but args size is " << args_conf_list.size()
<< " and inputs size is " << out_node_inputs.size();
}
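// The CNode inputs are {mixed_precision_cast, dst_type, value}, so the
// evaluated argument list holds {dst_type abstract, value abstract}.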
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
// out_conf was already dereferenced by the check above, so take its scope directly.
ScopePtr scope = out_conf->node()->scope();
ScopeGuard scope_guard(scope);
FuncGraphPtr func_graph = out_conf->node()->func_graph();
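// Expand the cast in place: inputs[1] is the target dtype, inputs[2] the value;
// the rewritten subgraph replaces the original CNode through ForwardConfig.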
AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
return engine->ForwardConfig(out_conf, fn_conf);
}
namespace {
py::object BuildValue(const ValuePtr &value_ptr) {
if (value_ptr == nullptr) {

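The recursion above reproduces in C++ the dispatch of the Python multitype helper removed later in this commit. A hedged pure-Python model of the same semantics, where is_float_tensor and cast are hypothetical stand-ins for the dtype check and P.Cast():

def mp_cast(dst_type, x):
    # Model of MixedPrecisionCastHelper: tuples recurse, float tensors are
    # cast, and everything else passes through unchanged.
    if isinstance(x, tuple):
        return tuple(mp_cast(dst_type, item) for item in x)
    if is_float_tensor(x):  # hypothetical: checks for a float-dtype Tensor
        return cast(x, dst_type)  # hypothetical: stands in for P.Cast()
    return x
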
View File

@@ -102,6 +102,22 @@ class UnpackGraphEvaluator : public Evaluator {
PrimitivePtr prim_;
};
class MixedPrecisionCastEvaluator : public Evaluator {
public:
explicit MixedPrecisionCastEvaluator(const PrimitivePtr &primitive)
: Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
~MixedPrecisionCastEvaluator() override = default;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
private:
PrimitivePtr prim_;
};
bool IsInWhiteList(PrimitivePtr primitive);
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);

View File

@@ -308,6 +308,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
return evaluator;
}
if (prim->name() == prim::kPrimMixedPrecisionCast->name()) {
evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
return evaluator;
}
if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) {

View File

@@ -21,7 +21,6 @@ from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite.base import _mp_cast_helper
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer
@@ -345,7 +344,7 @@ class WithEvalCell(Cell):
def construct(self, data, label):
outputs = self._network(data)
if self.add_cast_fp32:
label = _mp_cast_helper(mstype.float32, label)
label = F.mixed_precision_cast(mstype.float32, label)
outputs = F.cast(outputs, mstype.float32)
loss = self._loss_fn(outputs, label)
return loss, outputs, label
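
A minimal usage sketch for the updated eval cell, assuming the standard WithEvalCell constructor:

eval_net = nn.WithEvalCell(network, loss_fn, add_cast_fp32=True)
# label is cast to float32 through mixed_precision_cast before the loss is computed.
loss, outputs, label = eval_net(data, label)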

View File

@@ -24,7 +24,6 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeF
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec
from .. import functional as F
from .. import operations as P
from ...common.parameter import Parameter
@@ -297,32 +296,3 @@ env_get = MultitypeFuncGraph("env_get")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')
@_mp_cast_helper.register("TypeType", "Number")
@core
def _mixed_precision_cast_helper_1(type_, x):
"""if x is float cast to type."""
# type_ is place holder
return x
@_mp_cast_helper.register("TypeType", "Tensor")
@core
def _mixed_precision_cast_helper_2(type_, x):
"""if x is float cast to type."""
if F.issubclass_(F.dtype(x), mstype.float_):
return P.Cast()(x, type_)
return x
@_mp_cast_helper.register("TypeType", "Tuple")
@core
def _mixed_precision_cast_helper_3(type_, x):
"""if x is a tuple"""
t = ()
for item in x:
t = t + (_mp_cast_helper(type_, item),)
return t

View File

@@ -126,6 +126,7 @@ is_ = Primitive("is_")
is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
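
Since the evaluator recurses through tuples, the functional form accepts nested structures at call sites as well, e.g. this hedged example:

logits, label = F.mixed_precision_cast(mstype.float32, (logits, label))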

View File

@@ -21,7 +21,6 @@ from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F
from ..ops.composite.base import _mp_cast_helper
from ..parallel._utils import _get_parallel_mode
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from .parallel_utils import ParallelMode
@@ -98,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
def construct(self, data, label):
out = self._backbone(data)
label = _mp_cast_helper(mstype.float32, label)
label = F.mixed_precision_cast(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label)
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)