forked from mindspore-Ecosystem/mindspore
!767 [bug]with eval cell show cast is not support in gpu pynative
Merge pull request !767 from vlne-v1/I1F48Y-with-eval-cell-show-cast-is-not-support-in-gpu-pynative
This commit is contained in:
commit
64b3b566ab
|
@ -35,7 +35,6 @@
|
|||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
|
||||
// Expand the tuple and dict parameters generated when parsing the function call,
|
||||
// and generate positional parameters and key-value pairs for function.
|
||||
class UnpackCall : public MetaFuncGraph {
|
||||
|
@ -47,7 +46,6 @@ class UnpackCall : public MetaFuncGraph {
|
|||
friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
|
||||
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -300,6 +300,10 @@ void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) {
|
|||
// save the graph to file in protobuf format
|
||||
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (phase_s.empty()) {
|
||||
MS_LOG(ERROR) << "`phase` is empty '" << phase_s << "'!";
|
||||
return;
|
||||
}
|
||||
std::string name_prefix = phase_s.substr(0, phase_s.find("."));
|
||||
std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb";
|
||||
std::string filename = GetFilePathName(pb_filename);
|
||||
|
|
|
@ -304,15 +304,19 @@ class WithEvalCell(Cell):
|
|||
>>> eval_net = nn.WithEvalCell(net, loss_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, network, loss_fn):
|
||||
def __init__(self, network, loss_fn, add_cast_fp32=False):
|
||||
super(WithEvalCell, self).__init__(auto_prefix=False)
|
||||
self._network = network
|
||||
self._loss_fn = loss_fn
|
||||
self.add_cast_fp32 = add_cast_fp32
|
||||
|
||||
|
||||
def construct(self, data, label):
|
||||
outputs = self._network(data)
|
||||
label = _mp_cast_helper(mstype.float32, label)
|
||||
loss = self._loss_fn(F.cast(outputs, mstype.float32), label)
|
||||
if self.add_cast_fp32:
|
||||
label = _mp_cast_helper(mstype.float32, label)
|
||||
outputs = F.cast(outputs, mstype.float32)
|
||||
loss = self._loss_fn(outputs, label)
|
||||
return loss, outputs, label
|
||||
|
||||
|
||||
|
|
|
@ -162,7 +162,7 @@ class Model:
|
|||
else:
|
||||
if self._loss_fn is None:
|
||||
raise ValueError("loss_fn can not be None.")
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
|
||||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
def _build_predict_network(self):
|
||||
|
|
Loading…
Reference in New Issue