forked from mindspore-Ecosystem/mindspore
Optimize MixedPrecisionCast function
parent 394cd073ae
commit 041f628cce

@@ -242,6 +242,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
 const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
 const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
 const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
+const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
 
 // Comm ops
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

@@ -251,6 +251,7 @@ extern const PrimitivePtr kPrimIs_;
 extern const PrimitivePtr kPrimIsNot;
 extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
+extern const PrimitivePtr kPrimMixedPrecisionCast;
 
 // Comm ops
 extern const PrimitivePtr kPrimMirror;

@@ -67,7 +67,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
   } else {
     return param;
   }
-  auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base");
+  auto cast_helper = prim::kPrimMixedPrecisionCast;
   auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
   return cast;
 }

@@ -147,9 +147,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
-  if (!prim_->isa<prim::DoSignaturePrimitive>()) {
-    MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
-  }
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }

@@ -221,9 +218,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
-  if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
-    MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
-  }
 
   auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
   auto out_node = out_conf->node()->cast<CNodePtr>();

@@ -267,6 +261,63 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   return engine->ForwardConfig(out_conf, fn_conf);
 }
 
+AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type,
+                                    FuncGraphPtr func_graph) {
+  AnfNodePtr target_node = source_node;
+  if (node_type->isa<AbstractTensor>()) {
+    auto x = node_type->cast<AbstractTensorPtr>();
+    if (x->element()->BuildType()->isa<Float>()) {
+      auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+      MS_EXCEPTION_IF_NULL(cast);
+      target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type});
+    }
+  } else if (node_type->isa<AbstractTuple>()) {
+    auto x = node_type->cast<AbstractTuplePtr>();
+    auto &items = x->elements();
+    std::size_t size = items.size();
+    std::vector<AnfNodePtr> nodes;
+    nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (int i = 0; i < SizeToInt(size); i++) {
+      AnfNodePtr tuple_node =
+        func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(i)});
+      AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, items[i], target_type, func_graph);
+      nodes.emplace_back(node);
+    }
+    target_node = func_graph->NewCNode(nodes);
+  }
+  return target_node;
+}
+
+EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
+                                               AnfNodeConfigPtr out_conf) {
+  AbstractBasePtrList args_spec_list;
+  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
+  }
+  auto out_node = out_conf->node()->cast<CNodePtr>();
+  const auto &out_node_inputs = out_node->inputs();
+  if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
+    MS_LOG(EXCEPTION) << "MixedPrecisionCast"
+                      << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
+                      << ", inputs size " << out_node_inputs.size();
+  }
+  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
+
+  ScopePtr scope = kDefaultScope;
+  if (out_conf != nullptr) {
+    scope = out_conf->node()->scope();
+  }
+  ScopeGuard scope_guard(scope);
+
+  FuncGraphPtr func_graph = out_conf->node()->func_graph();
+  AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
+  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
+
+  return engine->ForwardConfig(out_conf, fn_conf);
+}
+
 namespace {
 py::object BuildValue(const ValuePtr &value_ptr) {
   if (value_ptr == nullptr) {
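
In effect, MixedPrecisionCastHelper walks the inferred abstract type rather than a runtime value: a float tensor gets a cast node, a tuple is unpacked with tuple_getitem, converted element by element, and repacked with make_tuple, and anything else passes through. Below is a runnable Python model of that rewrite; it is a sketch only, with graph nodes modelled as plain tuples and the names make_node and the "float_tensor" tag invented for illustration:

    def make_node(op, *args):
        # Stand-in for func_graph->NewCNode: a node is just ("op", args...).
        return (op,) + args

    def mixed_precision_cast_helper(source, abstract, dst_type):
        """abstract mirrors the inferred type: 'float_tensor', 'other', or a tuple of these."""
        if abstract == "float_tensor":
            # AbstractTensor with a Float element type: insert a cast node.
            return make_node("cast", source, dst_type)
        if isinstance(abstract, tuple):
            # AbstractTuple: cast each element, then rebuild the tuple.
            items = [make_node("tuple_getitem", source, i) for i in range(len(abstract))]
            casted = [mixed_precision_cast_helper(n, a, dst_type) for n, a in zip(items, abstract)]
            return make_node("make_tuple", *casted)
        return source  # non-float leaves are left untouched

    # A (float tensor, int tensor) pair: only the first element gets a cast node.
    print(mixed_precision_cast_helper("x", ("float_tensor", "other"), "float16"))
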
@@ -102,6 +102,22 @@ class UnpackGraphEvaluator : public Evaluator {
   PrimitivePtr prim_;
 };
 
+class MixedPrecisionCastEvaluator : public Evaluator {
+ public:
+  explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive)
+      : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
+  ~MixedPrecisionCastEvaluator() override = default;
+  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                    AnfNodeConfigPtr out_config = nullptr) override;
+
+  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
 bool IsInWhiteList(PrimitivePtr primitive);
 StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);

@@ -308,6 +308,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
     return evaluator;
   }
+  if (prim->name() == prim::kPrimMixedPrecisionCast->name()) {
+    evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
+    return evaluator;
+  }
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
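
GetPrimEvaluator dispatches on the primitive's name string, so "mixed_precision_cast" must be spelled identically in the C++ primitive definition above and in functional.py further down. A minimal Python sketch of this name-keyed registration (the class and table here are hypothetical stand-ins, not MindSpore code):

    class MixedPrecisionCastEvaluator:
        def __init__(self, prim):
            self.prim = prim

    SPECIAL_EVALUATORS = {
        # ... other specially handled primitives elided ...
        "mixed_precision_cast": MixedPrecisionCastEvaluator,  # registered by this commit
    }

    def get_prim_evaluator(prim_name, prim):
        # Fall back to None (standard handling) when no special evaluator matches.
        ctor = SPECIAL_EVALUATORS.get(prim_name)
        return ctor(prim) if ctor is not None else None
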
@@ -21,7 +21,6 @@ from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P
-from ...ops.composite.base import _mp_cast_helper
 from ...ops.operations.comm_ops import _VirtualDataset
 from ..cell import Cell
 from .grad_reducer import DistributedGradReducer

@@ -345,7 +344,7 @@ class WithEvalCell(Cell):
     def construct(self, data, label):
         outputs = self._network(data)
         if self.add_cast_fp32:
-            label = _mp_cast_helper(mstype.float32, label)
+            label = F.mixed_precision_cast(mstype.float32, label)
             outputs = F.cast(outputs, mstype.float32)
         loss = self._loss_fn(outputs, label)
         return loss, outputs, label
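
The call-site change is mechanical: the Python composite _mp_cast_helper is replaced by the functional handle F.mixed_precision_cast with the same (dst_type, value) argument order, and the Number/Tensor/Tuple handling the composite did in Python now happens in the C++ evaluator. A self-contained sketch of the updated eval-cell pattern under graph mode (the class name is illustrative, not part of the commit):

    import mindspore.nn as nn
    from mindspore.common import dtype as mstype
    from mindspore.ops import functional as F

    class EvalCellSketch(nn.Cell):
        def __init__(self, network, loss_fn, add_cast_fp32=True):
            super(EvalCellSketch, self).__init__(auto_prefix=False)
            self._network = network
            self._loss_fn = loss_fn
            self.add_cast_fp32 = add_cast_fp32

        def construct(self, data, label):
            outputs = self._network(data)
            if self.add_cast_fp32:
                # Recursively casts float tensors (including those nested in
                # tuples) to float32; numbers pass through unchanged.
                label = F.mixed_precision_cast(mstype.float32, label)
                outputs = F.cast(outputs, mstype.float32)
            loss = self._loss_fn(outputs, label)
            return loss, outputs, label
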
@@ -24,7 +24,6 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeF
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec
 from .. import functional as F
-from .. import operations as P
 from ...common.parameter import Parameter

@@ -297,32 +296,3 @@ env_get = MultitypeFuncGraph("env_get")
 def _tensor_env_get(env, parameter):
     """Used to get env."""
     return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
-
-
-_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')
-
-
-@_mp_cast_helper.register("TypeType", "Number")
-@core
-def _mixed_precision_cast_helper_1(type_, x):
-    """if x is float cast to type."""
-    # type_ is place holder
-    return x
-
-
-@_mp_cast_helper.register("TypeType", "Tensor")
-@core
-def _mixed_precision_cast_helper_2(type_, x):
-    """if x is float cast to type."""
-    if F.issubclass_(F.dtype(x), mstype.float_):
-        return P.Cast()(x, type_)
-    return x
-
-@_mp_cast_helper.register("TypeType", "Tuple")
-@core
-def _mixed_precision_cast_helper_3(type_, x):
-    """if x is a tuple"""
-    t = ()
-    for item in x:
-        t = t + (_mp_cast_helper(type_, item),)
-    return t
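
The three deleted overloads map one-to-one onto the C++ replacement: the "Number" case is the pass-through branch, the "Tensor" case is the Float check plus cast, and the "Tuple" case is the per-element recursion in MixedPrecisionCastHelper. What disappears is the MultitypeFuncGraph mechanism, which picked an overload from the abstract types of the arguments at specialization time. A toy, runnable model of that dispatch pattern (pure Python, not the MindSpore class):

    class MultitypeDispatch:
        """Toy model of MultitypeFuncGraph: pick an overload by argument type tags."""
        def __init__(self):
            self._overloads = {}

        def register(self, *tags):
            def deco(fn):
                self._overloads[tags] = fn
                return fn
            return deco

        def __call__(self, *tagged_args):
            tags = tuple(tag for tag, _ in tagged_args)
            values = tuple(value for _, value in tagged_args)
            return self._overloads[tags](*values)

    mp_cast = MultitypeDispatch()

    @mp_cast.register("TypeType", "Number")
    def _cast_number(type_, x):
        return x  # numbers pass through, matching the deleted overload

    print(mp_cast(("TypeType", "float32"), ("Number", 3)))  # -> 3
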
@@ -126,6 +126,7 @@ is_ = Primitive("is_")
 is_not = Primitive("is_not")
 in_dict = Primitive("in_dict")
 not_in_dict = Primitive("not_in_dict")
+mixed_precision_cast = Primitive("mixed_precision_cast")
 broadcast_gradient_args = Primitive('BroadcastGradientArgs')
 dot = Primitive('dot')
 array_reduce = Primitive('array_reduce')

@@ -21,7 +21,6 @@ from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..ops import functional as F
-from ..ops.composite.base import _mp_cast_helper
 from ..parallel._utils import _get_parallel_mode
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from .parallel_utils import ParallelMode

@@ -98,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
 
         def construct(self, data, label):
             out = self._backbone(data)
-            label = _mp_cast_helper(mstype.float32, label)
+            label = F.mixed_precision_cast(mstype.float32, label)
             return self._loss_fn(F.cast(out, mstype.float32), label)
 
     validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)