commit
9dd9425121
|
@ -175,6 +175,7 @@ AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out, con
|
|||
// Build ones_like(out) as dout, shape is same with out.sens_value its id hold by pynative execute, which can be
|
||||
// replace forward, but out is not.
|
||||
if (ValueHasDynamicShape(out)) {
|
||||
MS_EXCEPTION_IF_NULL(sens_value);
|
||||
auto value_node = NewValueNode(sens_value);
|
||||
auto value_node_abs = sens_value->ToAbstract()->Broaden();
|
||||
MS_LOG(DEBUG) << "Sens value abstract " << value_node_abs->ToString();
|
||||
|
|
|
@ -2685,6 +2685,7 @@ void ForwardExecutor::SetFeedDynamicInputAbs(const py::object &cell, const py::a
|
|||
<< PyObjToValue(args[i])->ToAbstract()->Broaden()->ToString()
|
||||
<< ", dynamic abs: " << abs->ToString();
|
||||
dynamic_shape_info_ptr()->obj_id_with_dynamic_output_abs[arg_id] = abs;
|
||||
(void)node_abs_map_.erase(arg_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3001,10 +3002,19 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
|
|||
auto param_i = only_tensors[i];
|
||||
const auto ¶m_i_value = PyObjToValue(param_i);
|
||||
input_param_values.emplace_back(param_i_value);
|
||||
auto param_i_abs = param_i_value->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(param_i_abs);
|
||||
new_param->set_abstract(param_i_abs->Broaden());
|
||||
const auto ¶m_i_id = GetId(param_i);
|
||||
abstract::AbstractBasePtr param_i_abs = nullptr;
|
||||
auto item = forward()->dynamic_shape_info_ptr()->obj_id_with_dynamic_output_abs.find(param_i_id);
|
||||
if (item != forward()->dynamic_shape_info_ptr()->obj_id_with_dynamic_output_abs.end()) {
|
||||
MS_LOG(DEBUG) << "Param " << i << " is dynamic input";
|
||||
param_i_abs = item->second;
|
||||
} else {
|
||||
param_i_abs = param_i_value->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(param_i_abs);
|
||||
param_i_abs = param_i_abs->Broaden();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(param_i_abs);
|
||||
new_param->set_abstract(param_i_abs);
|
||||
SetTupleArgsToGraphInfoMap(curr_g(), param_i, new_param, true);
|
||||
SetNodeMapInGraphInfoMap(curr_g(), param_i_id, new_param);
|
||||
SetParamNodeMapInGraphInfoMap(curr_g(), param_i_id, new_param);
|
||||
|
@ -3102,6 +3112,18 @@ void GradExecutor::NewGraphInner(const py::object *ret, const py::object &cell,
|
|||
}
|
||||
}
|
||||
|
||||
void GradExecutor::ChangeTopCellInfo(const TopCellInfoPtr &top_cell, const std::vector<ShapeVector> &new_args_shape) {
|
||||
MS_EXCEPTION_IF_NULL(top_cell);
|
||||
std::string new_cell_id = top_cell->cell_self_info()->cell_self_id;
|
||||
for (size_t i = 0; i < new_args_shape.size(); ++i) {
|
||||
new_cell_id += "_" + top_cell->cell_self_info()->args_shape[i]->ToString();
|
||||
new_cell_id += top_cell->cell_self_info()->args_type[i]->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Change top cell " << top_cell->cell_id() << " to be dynamic " << new_cell_id;
|
||||
top_cell->set_cell_id(new_cell_id);
|
||||
top_cell->set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
|
||||
}
|
||||
|
||||
TopCellInfoPtr GradExecutor::ChangeTopCellToDynamicShapeByAuto(const TopCellInfoPtr &top_cell,
|
||||
const std::vector<ShapeVector> &new_args_shape,
|
||||
const py::object &cell, const py::args &args) {
|
||||
|
@ -3119,15 +3141,7 @@ TopCellInfoPtr GradExecutor::ChangeTopCellToDynamicShapeByAuto(const TopCellInfo
|
|||
MS_LOG(DEBUG) << "Set dynamic input for auto dynamic shape";
|
||||
forward()->SetDynamicInput(cell, args);
|
||||
forward()->SetFeedDynamicInputAbs(cell, args);
|
||||
// Change cell id
|
||||
std::string new_cell_id = top_cell->cell_self_info()->cell_self_id;
|
||||
for (size_t i = 0; i < new_args_shape.size(); ++i) {
|
||||
new_cell_id += "_" + top_cell->cell_self_info()->args_shape[i]->ToString();
|
||||
new_cell_id += top_cell->cell_self_info()->args_type[i]->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Change top cell " << top_cell->cell_id() << " to be dynamic " << new_cell_id;
|
||||
top_cell->set_cell_id(new_cell_id);
|
||||
top_cell->set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
|
||||
ChangeTopCellInfo(top_cell, new_args_shape);
|
||||
return top_cell;
|
||||
}
|
||||
|
||||
|
@ -3155,15 +3169,7 @@ TopCellInfoPtr GradExecutor::ChangeTopCellToDynamicShapeBySetInputs(const TopCel
|
|||
}
|
||||
}
|
||||
}
|
||||
// Change cell id
|
||||
std::string new_cell_id = top_cell->cell_self_info()->cell_self_id;
|
||||
for (size_t i = 0; i < new_args_shape.size(); ++i) {
|
||||
new_cell_id += "_" + top_cell->cell_self_info()->args_shape[i]->ToString();
|
||||
new_cell_id += top_cell->cell_self_info()->args_type[i]->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Change top cell " << top_cell->cell_id() << " to be dynamic " << new_cell_id;
|
||||
top_cell->set_cell_id(new_cell_id);
|
||||
top_cell->set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
|
||||
ChangeTopCellInfo(top_cell, new_args_shape);
|
||||
return top_cell;
|
||||
}
|
||||
|
||||
|
|
|
@ -284,6 +284,7 @@ class GradExecutor {
|
|||
void UpdateForwardTensorInfoInBpropGraph(const string &op_info, const ValuePtr &op_out);
|
||||
void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
|
||||
py::object CheckGraph(const py::object &cell, const py::args &args);
|
||||
void ChangeTopCellInfo(const TopCellInfoPtr &top_cell, const std::vector<ShapeVector> &new_args_shape);
|
||||
TopCellInfoPtr ChangeTopCellToDynamicShapeByAuto(const TopCellInfoPtr &top_cell,
|
||||
const std::vector<ShapeVector> &new_args_shape,
|
||||
const py::object &cell, const py::args &args);
|
||||
|
|
|
@ -34,6 +34,34 @@ constexpr char kAttrSorted[] = "sorted";
|
|||
constexpr char kAttrStrides[] = "strides";
|
||||
constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask";
|
||||
|
||||
bool CheckValueType(const AnfNodePtr &input_node, size_t inputs_num) {
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
if (!value->isa<tensor::Tensor>()) {
|
||||
MS_EXCEPTION(ValueError) << "The strides of StridedSliceGrad must be a constant." << inputs_num;
|
||||
}
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
TypePtr data_type = tensor->Dtype();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
TypeId type_id = data_type->type_id();
|
||||
auto element_size = tensor->data().size();
|
||||
if (type_id == kNumberTypeInt32) {
|
||||
auto *data = reinterpret_cast<int *>(tensor->data_c());
|
||||
if ((data[element_size - 1]) != 1) {
|
||||
return false;
|
||||
}
|
||||
} else if (type_id == kNumberTypeInt64) {
|
||||
auto *data = reinterpret_cast<int64_t *>(tensor->data_c());
|
||||
if ((data[element_size - 1]) != 1) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The strides of StridedSliceGrad must be int.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckStridedSlice(const CNodePtr &cnode) {
|
||||
// check stride[-1] != 1
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) {
|
||||
|
@ -47,31 +75,11 @@ static bool CheckStridedSlice(const CNodePtr &cnode) {
|
|||
if (inputs.size() == kInputNum + 1) {
|
||||
auto input_node = inputs[kInputNum];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
TypePtr data_type = tensor->Dtype();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
TypeId type_id = data_type->type_id();
|
||||
auto element_size = tensor->data().size();
|
||||
if (type_id == kNumberTypeInt32) {
|
||||
auto *data = reinterpret_cast<int *>(tensor->data_c());
|
||||
if ((data[element_size - 1]) != 1) {
|
||||
return false;
|
||||
}
|
||||
} else if (type_id == kNumberTypeInt64) {
|
||||
auto *data = reinterpret_cast<int64_t *>(tensor->data_c());
|
||||
if ((data[element_size - 1]) != 1) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The strides of StridedSliceGrad must be int.";
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "The strides of StridedSliceGrad must be a constant." << inputs.size();
|
||||
// Input node can be a cnode, like cast or transdata, which output is a valuenode
|
||||
if (input_node->isa<CNode>()) {
|
||||
return true;
|
||||
}
|
||||
return CheckValueType(input_node, inputs.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue