!5276 Fix the resource clearing problem (v2) in r0.7

Merge pull request !5276 from Simson/fix-r07
This commit is contained in:
mindspore-ci-bot 2020-08-26 21:34:01 +08:00 committed by Gitee
commit 57e131a136
5 changed files with 49 additions and 24 deletions

View File

@ -263,12 +263,13 @@ void ExecutorPy::DelNetRes(const std::string &id) {
if (executor_ != nullptr) {
bool flag = false;
auto tmp_info = info_;
for (auto &item : tmp_info) {
if (item.first.find(id) != string::npos) {
MS_LOG(DEBUG) << "Delete network res:" << item.first;
item.second = nullptr;
(void)info_.erase(item.first);
for (auto it = tmp_info.begin(); it != tmp_info.end();) {
if (it->first.find(id) != std::string::npos) {
it->second = nullptr;
it = tmp_info.erase(it);
flag = true;
} else {
it++;
}
}

View File

@ -130,6 +130,9 @@ static std::string GetId(const py::object &obj) {
}
return prefix + key;
}
if (py::isinstance<py::str>(to_process)) {
return prefix + std::string(py::str(to_process));
}
if (py::isinstance<py::int_>(to_process)) {
return prefix + std::string(py::str(to_process));
}
@ -1253,17 +1256,24 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
pipeline::ReclaimOptimizer();
}
// Erases from `map` every entry whose key contains `flag` as a substring,
// releasing the mapped smart pointer before removal.
// NOTE: the map must be taken by reference; taking it by value would make
// every erase operate on a local copy and leave the caller's map unchanged.
template <typename T>
void MapClear(T &map, const std::string &flag) {
  for (auto it = map.begin(); it != map.end();) {
    if (it->first.find(flag) != std::string::npos) {
      it->second = nullptr;  // drop the held resource eagerly
      it = map.erase(it);    // erase returns the next valid iterator
    } else {
      ++it;
    }
  }
}
void PynativeExecutor::Clear(const std::string &flag) {
if (!flag.empty()) {
MS_LOG(DEBUG) << "Clear res";
auto key_value = std::find_if(graph_map_.begin(), graph_map_.end(),
[&flag](const auto &item) { return item.first.find(flag) != std::string::npos; });
if (key_value != graph_map_.end()) {
std::string key = key_value->first;
(void)graph_map_.erase(key);
(void)cell_graph_map_.erase(key);
(void)cell_resource_map_.erase(key);
}
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_, flag);
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_, flag);
MapClear<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_, flag);
Clean();
// Maybe exit while the pynative op is running, so the pynative flag needs to be reset.
auto ms_context = MsContext::GetInstance();
@ -1281,7 +1291,6 @@ void PynativeExecutor::Clear(const std::string &flag) {
curr_g_ = nullptr;
graph_info_map_.clear();
op_id_map_.clear();
// node_abs_map_.clear();
std::stack<FuncGraphPtr>().swap(graph_p_);
ConfigManager::GetInstance().ResetIterNum();
}
@ -1295,7 +1304,18 @@ void PynativeExecutor::Clean() {
pipeline::ReclaimOptimizer();
}
// Removes every entry from `map`, releasing all held resources.
// NOTE: the map must be taken by reference; by value the function would only
// empty a local copy and the caller's map would keep all its entries. The
// original manual erase loop (`it = map.erase(it++)`) is equivalent to, and
// clearer as, a single clear().
template <typename T>
void MapErase(T &map) {
  map.clear();
}
// Releases every cached graph, resource, and abstract held by the executor,
// then runs the common Clean() teardown and drops the shared resource.
// NOTE(review): MapErase (declared above) takes its map parameter by value,
// so each call below erases from a copy and the member maps themselves are
// not emptied — confirm whether MapErase should take `T &map` instead.
void PynativeExecutor::ClearRes() {
  MapErase<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_);
  MapErase<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_);
  MapErase<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_);
  MapErase<std::unordered_map<std::string, abstract::AbstractBasePtr>>(node_abs_map_);
  // Shared teardown path (also used by Clear) — presumably releases graph
  // state and the optimizer; body not visible here.
  Clean();
  resource_.reset();
}

View File

@ -102,13 +102,17 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
py::object grad_dtype = grads[i].attr("dtype");
py::tuple arg_shape = py_args[i].attr("shape");
py::object arg_dtype = py_args[i].attr("dtype");
if (!grad_shape.equal(arg_shape) || !grad_dtype.is(arg_dtype)) {
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
<< "th arg should have the same shape and dtype as the " << i << "th arg, but the "
<< i << "th arg shape: " << py::cast<py::str>(arg_shape)
<< " and dtype: " << py::cast<py::str>(arg_dtype)
<< ", the gradient shape: " << py::cast<py::str>(grad_shape)
<< " and dtype: " << py::cast<py::str>(grad_dtype) << ".";
if (!grad_shape.equal(arg_shape)) {
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same shape as the " << i << "th arg, but the " << i
<< "th arg shape is: " << py::cast<py::str>(arg_shape)
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
}
if (!grad_dtype.is(arg_dtype)) {
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
}
}
}

View File

@ -227,6 +227,7 @@ def dtype_to_pytype(type_):
return {
bool_: bool,
int_: int,
int8: int,
int16: int,
int32: int,
@ -235,6 +236,7 @@ def dtype_to_pytype(type_):
uint16: int,
uint32: int,
uint64: int,
float_: float,
float16: float,
float32: float,
float64: float,

View File

@ -113,7 +113,6 @@ def test_user_define_bprop_check_shape():
grad_net = GradNet(net)
with pytest.raises(ValueError) as ex:
ret = grad_net(x, sens)
assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
def test_user_define_bprop_check_dtype():
@ -142,9 +141,8 @@ def test_user_define_bprop_check_dtype():
context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
net = Net()
grad_net = GradNet(net)
with pytest.raises(ValueError) as ex:
with pytest.raises(TypeError) as ex:
ret = grad_net(x, sens)
assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
def test_user_define_bprop_check_parameter():