!36251 Fix pynative bug

Merge pull request !36251 from zjun/fix_bug2
i-robot 2022-06-23 12:17:13 +00:00 committed by Gitee
commit 9dd9425121
4 changed files with 61 additions and 45 deletions


@@ -175,6 +175,7 @@ AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out, con
// Build ones_like(out) as dout; its shape is the same as out's. The id of sens_value is held by pynative execute,
// so it can be replaced in forward, but out cannot.
if (ValueHasDynamicShape(out)) {
MS_EXCEPTION_IF_NULL(sens_value);
auto value_node = NewValueNode(sens_value);
auto value_node_abs = sens_value->ToAbstract()->Broaden();
MS_LOG(DEBUG) << "Sens value abstract " << value_node_abs->ToString();


@@ -2685,6 +2685,7 @@ void ForwardExecutor::SetFeedDynamicInputAbs(const py::object &cell, const py::a
<< PyObjToValue(args[i])->ToAbstract()->Broaden()->ToString()
<< ", dynamic abs: " << abs->ToString();
dynamic_shape_info_ptr()->obj_id_with_dynamic_output_abs[arg_id] = abs;
(void)node_abs_map_.erase(arg_id);
}
}
}
@@ -3001,10 +3002,19 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
auto param_i = only_tensors[i];
const auto &param_i_value = PyObjToValue(param_i);
input_param_values.emplace_back(param_i_value);
auto param_i_abs = param_i_value->ToAbstract();
MS_EXCEPTION_IF_NULL(param_i_abs);
new_param->set_abstract(param_i_abs->Broaden());
const auto &param_i_id = GetId(param_i);
abstract::AbstractBasePtr param_i_abs = nullptr;
auto item = forward()->dynamic_shape_info_ptr()->obj_id_with_dynamic_output_abs.find(param_i_id);
if (item != forward()->dynamic_shape_info_ptr()->obj_id_with_dynamic_output_abs.end()) {
MS_LOG(DEBUG) << "Param " << i << " is dynamic input";
param_i_abs = item->second;
} else {
param_i_abs = param_i_value->ToAbstract();
MS_EXCEPTION_IF_NULL(param_i_abs);
param_i_abs = param_i_abs->Broaden();
}
MS_EXCEPTION_IF_NULL(param_i_abs);
new_param->set_abstract(param_i_abs);
SetTupleArgsToGraphInfoMap(curr_g(), param_i, new_param, true);
SetNodeMapInGraphInfoMap(curr_g(), param_i_id, new_param);
SetParamNodeMapInGraphInfoMap(curr_g(), param_i_id, new_param);
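
A minimal sketch of the lookup-then-fallback pattern the new block uses, with illustrative names (Abstract, dynamic_abs, and AbsForParam are stand-ins, not MindSpore types): if the argument id was registered as a dynamic input (by SetFeedDynamicInputAbs above, which also erases the stale cached entry from node_abs_map_), reuse the recorded dynamic abstract; otherwise derive an abstract from the concrete value and broaden it.

    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>

    // Illustrative stand-in for an abstract (the shape/type summary of a value).
    struct Abstract {
      std::string desc;
    };
    using AbstractPtr = std::shared_ptr<Abstract>;

    // Stand-in for obj_id_with_dynamic_output_abs: abstracts registered for
    // inputs that were fed as dynamic, keyed by object id.
    std::unordered_map<std::string, AbstractPtr> dynamic_abs;

    // Prefer the registered dynamic abstract; otherwise derive one from the
    // concrete value and broaden it (tagged here in place of a real Broaden()).
    AbstractPtr AbsForParam(const std::string &id, const std::string &concrete) {
      auto it = dynamic_abs.find(id);
      if (it != dynamic_abs.end()) {
        return it->second;  // dynamic input: keep the fed abstract
      }
      return std::make_shared<Abstract>(Abstract{concrete + " (broadened)"});
    }

    int main() {
      dynamic_abs["x"] = std::make_shared<Abstract>(Abstract{"Tensor[-1, 32]"});
      std::cout << AbsForParam("x", "Tensor[4, 32]")->desc << "\n";  // Tensor[-1, 32]
      std::cout << AbsForParam("y", "Tensor[4, 32]")->desc << "\n";  // Tensor[4, 32] (broadened)
      return 0;
    }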
@@ -3102,6 +3112,18 @@ void GradExecutor::NewGraphInner(const py::object *ret, const py::object &cell,
}
}
void GradExecutor::ChangeTopCellInfo(const TopCellInfoPtr &top_cell, const std::vector<ShapeVector> &new_args_shape) {
MS_EXCEPTION_IF_NULL(top_cell);
std::string new_cell_id = top_cell->cell_self_info()->cell_self_id;
for (size_t i = 0; i < new_args_shape.size(); ++i) {
new_cell_id += "_" + top_cell->cell_self_info()->args_shape[i]->ToString();
new_cell_id += top_cell->cell_self_info()->args_type[i]->ToString();
}
MS_LOG(DEBUG) << "Change top cell " << top_cell->cell_id() << " to be dynamic " << new_cell_id;
top_cell->set_cell_id(new_cell_id);
top_cell->set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
}
TopCellInfoPtr GradExecutor::ChangeTopCellToDynamicShapeByAuto(const TopCellInfoPtr &top_cell,
const std::vector<ShapeVector> &new_args_shape,
const py::object &cell, const py::args &args) {
@@ -3119,15 +3141,7 @@ TopCellInfoPtr GradExecutor::ChangeTopCellToDynamicShapeByAuto(const TopCellInfo
MS_LOG(DEBUG) << "Set dynamic input for auto dynamic shape";
forward()->SetDynamicInput(cell, args);
forward()->SetFeedDynamicInputAbs(cell, args);
// Change cell id
std::string new_cell_id = top_cell->cell_self_info()->cell_self_id;
for (size_t i = 0; i < new_args_shape.size(); ++i) {
new_cell_id += "_" + top_cell->cell_self_info()->args_shape[i]->ToString();
new_cell_id += top_cell->cell_self_info()->args_type[i]->ToString();
}
MS_LOG(DEBUG) << "Change top cell " << top_cell->cell_id() << " to be dynamic " << new_cell_id;
top_cell->set_cell_id(new_cell_id);
top_cell->set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
ChangeTopCellInfo(top_cell, new_args_shape);
return top_cell;
}
@@ -3155,15 +3169,7 @@ TopCellInfoPtr GradExecutor::ChangeTopCellToDynamicShapeBySetInputs(const TopCel
}
}
}
// Change cell id
std::string new_cell_id = top_cell->cell_self_info()->cell_self_id;
for (size_t i = 0; i < new_args_shape.size(); ++i) {
new_cell_id += "_" + top_cell->cell_self_info()->args_shape[i]->ToString();
new_cell_id += top_cell->cell_self_info()->args_type[i]->ToString();
}
MS_LOG(DEBUG) << "Change top cell " << top_cell->cell_id() << " to be dynamic " << new_cell_id;
top_cell->set_cell_id(new_cell_id);
top_cell->set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
ChangeTopCellInfo(top_cell, new_args_shape);
return top_cell;
}
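
The two deleted "Change cell id" blocks above were identical; both call sites now delegate to the new ChangeTopCellInfo. A small sketch of the id composition it performs, assuming string-rendered shapes and types (ComposeCellId and its parameters are illustrative, not the real signatures):

    #include <iostream>
    #include <string>
    #include <vector>

    // Compose a cell id the way ChangeTopCellInfo does: start from the cell's
    // own id, then append "_" + shape + type for every argument.
    std::string ComposeCellId(const std::string &cell_self_id,
                              const std::vector<std::string> &args_shape,
                              const std::vector<std::string> &args_type) {
      std::string id = cell_self_id;
      for (size_t i = 0; i < args_shape.size(); ++i) {
        id += "_" + args_shape[i];
        id += args_type[i];
      }
      return id;
    }

    int main() {
      // E.g. one argument whose first dimension became dynamic.
      std::cout << ComposeCellId("NetA", {"(-1,32)"}, {"Float32"}) << "\n";
      // Prints: NetA_(-1,32)Float32
      return 0;
    }

Appending shape and type per argument keeps top cells with different dynamic signatures from colliding under a single cell id.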


@@ -284,6 +284,7 @@ class GradExecutor {
void UpdateForwardTensorInfoInBpropGraph(const string &op_info, const ValuePtr &op_out);
void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
py::object CheckGraph(const py::object &cell, const py::args &args);
void ChangeTopCellInfo(const TopCellInfoPtr &top_cell, const std::vector<ShapeVector> &new_args_shape);
TopCellInfoPtr ChangeTopCellToDynamicShapeByAuto(const TopCellInfoPtr &top_cell,
const std::vector<ShapeVector> &new_args_shape,
const py::object &cell, const py::args &args);


@@ -34,6 +34,34 @@ constexpr char kAttrSorted[] = "sorted";
constexpr char kAttrStrides[] = "strides";
constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask";
bool CheckValueType(const AnfNodePtr &input_node, size_t inputs_num) {
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
if (!value->isa<tensor::Tensor>()) {
MS_EXCEPTION(ValueError) << "The strides of StridedSliceGrad must be a constant, but the inputs size is " << inputs_num;
}
auto tensor = value->cast<tensor::TensorPtr>();
TypePtr data_type = tensor->Dtype();
MS_EXCEPTION_IF_NULL(data_type);
TypeId type_id = data_type->type_id();
auto element_size = tensor->data().size();
if (type_id == kNumberTypeInt32) {
auto *data = reinterpret_cast<int *>(tensor->data_c());
if ((data[element_size - 1]) != 1) {
return false;
}
} else if (type_id == kNumberTypeInt64) {
auto *data = reinterpret_cast<int64_t *>(tensor->data_c());
if ((data[element_size - 1]) != 1) {
return false;
}
} else {
MS_EXCEPTION(TypeError) << "The strides of StridedSliceGrad must be int.";
}
return true;
}
static bool CheckStridedSlice(const CNodePtr &cnode) {
// check stride[-1] != 1
if (common::AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) {
@@ -47,31 +75,11 @@ static bool CheckStridedSlice(const CNodePtr &cnode) {
if (inputs.size() == kInputNum + 1) {
auto input_node = inputs[kInputNum];
MS_EXCEPTION_IF_NULL(input_node);
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
TypePtr data_type = tensor->Dtype();
MS_EXCEPTION_IF_NULL(data_type);
TypeId type_id = data_type->type_id();
auto element_size = tensor->data().size();
if (type_id == kNumberTypeInt32) {
auto *data = reinterpret_cast<int *>(tensor->data_c());
if ((data[element_size - 1]) != 1) {
return false;
}
} else if (type_id == kNumberTypeInt64) {
auto *data = reinterpret_cast<int64_t *>(tensor->data_c());
if ((data[element_size - 1]) != 1) {
return false;
}
} else {
MS_EXCEPTION(TypeError) << "The strides of StridedSliceGrad must be int.";
}
} else {
MS_EXCEPTION(ValueError) << "The strides of StridedSliceGrad must be a constant." << inputs.size();
// The input node can be a CNode, such as Cast or TransData, whose output is a ValueNode
if (input_node->isa<CNode>()) {
return true;
}
return CheckValueType(input_node, inputs.size());
}
}
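
Finally, a self-contained sketch of the check the extracted CheckValueType performs, under the assumption that the strides arrive as a raw buffer tagged as int32 or int64 (LastStrideIsOne and DType are illustrative names): read the last stride at the tagged width and report whether it equals 1.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Stand-in for the strides value node's payload: a type tag for the buffer.
    enum class DType { kInt32, kInt64 };

    // Mirrors the type dispatch in CheckValueType: reinterpret the buffer at
    // the tagged width and test whether the last stride equals 1.
    bool LastStrideIsOne(const void *data, size_t element_count, DType dtype) {
      if (dtype == DType::kInt32) {
        auto *p = static_cast<const int32_t *>(data);
        return p[element_count - 1] == 1;
      }
      auto *p = static_cast<const int64_t *>(data);
      return p[element_count - 1] == 1;
    }

    int main() {
      std::vector<int32_t> strides32 = {1, 1, 2};
      std::vector<int64_t> strides64 = {1, 1, 1};
      std::cout << LastStrideIsOne(strides32.data(), strides32.size(), DType::kInt32) << "\n";  // 0
      std::cout << LastStrideIsOne(strides64.data(), strides64.size(), DType::kInt64) << "\n";  // 1
      return 0;
    }

The new early return in CheckStridedSlice skips this check entirely when the strides input is a CNode (e.g. a Cast or TransData) rather than a constant, since its value is not known at that point.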