forked from mindspore-Ecosystem/mindspore
!5276 Fix the problem of resource clear v2 in r0.7
Merge pull request !5276 from Simson/fix-r07
This commit is contained in:
commit
57e131a136
|
@ -263,12 +263,13 @@ void ExecutorPy::DelNetRes(const std::string &id) {
|
|||
if (executor_ != nullptr) {
|
||||
bool flag = false;
|
||||
auto tmp_info = info_;
|
||||
for (auto &item : tmp_info) {
|
||||
if (item.first.find(id) != string::npos) {
|
||||
MS_LOG(DEBUG) << "Delete network res:" << item.first;
|
||||
item.second = nullptr;
|
||||
(void)info_.erase(item.first);
|
||||
for (auto it = tmp_info.begin(); it != tmp_info.end();) {
|
||||
if (it->first.find(id) != std::string::npos) {
|
||||
it->second = nullptr;
|
||||
it = tmp_info.erase(it);
|
||||
flag = true;
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -130,6 +130,9 @@ static std::string GetId(const py::object &obj) {
|
|||
}
|
||||
return prefix + key;
|
||||
}
|
||||
if (py::isinstance<py::str>(to_process)) {
|
||||
return prefix + std::string(py::str(to_process));
|
||||
}
|
||||
if (py::isinstance<py::int_>(to_process)) {
|
||||
return prefix + std::string(py::str(to_process));
|
||||
}
|
||||
|
@ -1253,17 +1256,24 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
|
|||
pipeline::ReclaimOptimizer();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MapClear(T map, const std::string &flag) {
|
||||
for (auto it = map.begin(); it != map.end();) {
|
||||
if (it->first.find(flag) != std::string::npos) {
|
||||
it->second = nullptr;
|
||||
it = map.erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PynativeExecutor::Clear(const std::string &flag) {
|
||||
if (!flag.empty()) {
|
||||
MS_LOG(DEBUG) << "Clear res";
|
||||
auto key_value = std::find_if(graph_map_.begin(), graph_map_.end(),
|
||||
[&flag](const auto &item) { return item.first.find(flag) != std::string::npos; });
|
||||
if (key_value != graph_map_.end()) {
|
||||
std::string key = key_value->first;
|
||||
(void)graph_map_.erase(key);
|
||||
(void)cell_graph_map_.erase(key);
|
||||
(void)cell_resource_map_.erase(key);
|
||||
}
|
||||
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_, flag);
|
||||
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_, flag);
|
||||
MapClear<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_, flag);
|
||||
Clean();
|
||||
// Maybe exit in the pynative runing op, so need reset pynative flag.
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -1281,7 +1291,6 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|||
curr_g_ = nullptr;
|
||||
graph_info_map_.clear();
|
||||
op_id_map_.clear();
|
||||
// node_abs_map_.clear();
|
||||
std::stack<FuncGraphPtr>().swap(graph_p_);
|
||||
ConfigManager::GetInstance().ResetIterNum();
|
||||
}
|
||||
|
@ -1295,7 +1304,18 @@ void PynativeExecutor::Clean() {
|
|||
pipeline::ReclaimOptimizer();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MapErase(T map) {
|
||||
for (auto it = map.begin(); it != map.end();) {
|
||||
it = map.erase(it++);
|
||||
}
|
||||
}
|
||||
|
||||
void PynativeExecutor::ClearRes() {
|
||||
MapErase<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_);
|
||||
MapErase<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_);
|
||||
MapErase<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_);
|
||||
MapErase<std::unordered_map<std::string, abstract::AbstractBasePtr>>(node_abs_map_);
|
||||
Clean();
|
||||
resource_.reset();
|
||||
}
|
||||
|
|
|
@ -102,13 +102,17 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|||
py::object grad_dtype = grads[i].attr("dtype");
|
||||
py::tuple arg_shape = py_args[i].attr("shape");
|
||||
py::object arg_dtype = py_args[i].attr("dtype");
|
||||
if (!grad_shape.equal(arg_shape) || !grad_dtype.is(arg_dtype)) {
|
||||
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
|
||||
<< "th arg should have the same shape and dtype as the " << i << "th arg, but the "
|
||||
<< i << "th arg shape: " << py::cast<py::str>(arg_shape)
|
||||
<< " and dtype: " << py::cast<py::str>(arg_dtype)
|
||||
<< ", the gradient shape: " << py::cast<py::str>(grad_shape)
|
||||
<< " and dtype: " << py::cast<py::str>(grad_dtype) << ".";
|
||||
if (!grad_shape.equal(arg_shape)) {
|
||||
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
|
||||
<< "th arg should have the same shape as the " << i << "th arg, but the " << i
|
||||
<< "th arg shape is: " << py::cast<py::str>(arg_shape)
|
||||
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
|
||||
}
|
||||
if (!grad_dtype.is(arg_dtype)) {
|
||||
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
|
||||
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
|
||||
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
|
||||
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -227,6 +227,7 @@ def dtype_to_pytype(type_):
|
|||
|
||||
return {
|
||||
bool_: bool,
|
||||
int_: int,
|
||||
int8: int,
|
||||
int16: int,
|
||||
int32: int,
|
||||
|
@ -235,6 +236,7 @@ def dtype_to_pytype(type_):
|
|||
uint16: int,
|
||||
uint32: int,
|
||||
uint64: int,
|
||||
float_: float,
|
||||
float16: float,
|
||||
float32: float,
|
||||
float64: float,
|
||||
|
|
|
@ -113,7 +113,6 @@ def test_user_define_bprop_check_shape():
|
|||
grad_net = GradNet(net)
|
||||
with pytest.raises(ValueError) as ex:
|
||||
ret = grad_net(x, sens)
|
||||
assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
|
||||
|
||||
|
||||
def test_user_define_bprop_check_dtype():
|
||||
|
@ -142,9 +141,8 @@ def test_user_define_bprop_check_dtype():
|
|||
context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
|
||||
net = Net()
|
||||
grad_net = GradNet(net)
|
||||
with pytest.raises(ValueError) as ex:
|
||||
with pytest.raises(TypeError) as ex:
|
||||
ret = grad_net(x, sens)
|
||||
assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
|
||||
|
||||
|
||||
def test_user_define_bprop_check_parameter():
|
||||
|
|
Loading…
Reference in New Issue