diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
index 84ee13a19ed..9939c223079 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
@@ -25,6 +25,7 @@
 #include "pipeline/jit/parse/data_converter.h"
 #include "pipeline/jit/debug/trace.h"
 #include "frontend/optimizer/ad/prim_bprop_optimizer.h"
+#include "backend/common/optimizer/helper.h"
 #include "include/common/utils/convert_utils_py.h"
 #include "frontend/optimizer/ad/grad.h"
 #include "pipeline/jit/pass.h"
@@ -1448,11 +1449,11 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info, co
   const auto &obj_id = PyNativeAlgo::Common::GetIdByValue(v);
   cnode->set_abstract(op_run_info->base_op_run_info.abstract);
   SaveOutputNodeMap(obj_id, v, cnode);
-  DoOpGrad(op_run_info, cnode, v);
   // Dynamic shape should update to top cell
   if (PyNativeAlgo::Common::IsDynamicShape(op_run_info)) {
     top_cell()->set_dynamic_shape(true);
   }
+  DoOpGrad(op_run_info, cnode, v);
   forward()->SetNodeAbsMapByValue(v, op_run_info->base_op_run_info.abstract);
   UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, v);
@@ -1495,15 +1496,41 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
   if (op_run_info->run_in_vm) {
     input_args = op_run_info->input_value;
   } else {
-    // Run in Ms, some op input tensor convert into attributes, so add them back
-    for (auto &it : op_run_info->index_with_value) {
-      input_args[it.first] = it.second;
-    }
-    // Add other tensor
-    for (size_t i = 0; i < op_run_info->input_value.size(); ++i) {
-      if (input_args[i] == nullptr) {
-        input_args[i] = op_run_info->input_value[i];
+    auto convert_tuple_and_scalar_into_tensor = [&](size_t idx, const ValuePtr &default_value) -> bool {
+      if (top_cell()->dynamic_shape() &&
+          kDynamicInputOpMap.find(op_run_info->base_op_run_info.op_name) != kDynamicInputOpMap.end()) {
+        const auto &input_vec = kDynamicInputOpMap[op_run_info->base_op_run_info.op_name];
+        bool marked = std::any_of(input_vec.begin(), input_vec.end(), [&idx](size_t i) { return idx == i; });
+        if (marked) {
+          if (default_value->isa<ValueSequence>()) {
+            MS_LOG(DEBUG) << "Ready to convert tuple into tensor, op name:" << op_run_info->base_op_run_info.op_name
+                          << ", index:" << idx;
+            ValueSequencePtr value_seq = default_value->cast<ValueSequencePtr>();
+            ValueTuplePtr value_tuple;
+            if (value_seq->isa<ValueList>()) {
+              value_tuple = std::make_shared<ValueTuple>(value_seq->value());
+            } else {
+              value_tuple = value_seq->cast<ValueTuplePtr>();
+            }
+            auto tensor_ptr = opt::CreateTupleTensor(value_tuple);
+            input_args[idx] = tensor_ptr;
+            return true;
+          } else if (default_value->isa<Scalar>()) {
+            MS_LOG(DEBUG) << "Ready to convert scalar into tensor, op name:" << op_run_info->base_op_run_info.op_name
+                          << ", index:" << idx;
+            auto scalar_tensor = ScalarToTensor(default_value->cast<ScalarPtr>());
+            input_args[idx] = scalar_tensor;
+            return true;
+          }
+        }
       }
+      return false;
+    };
+    for (size_t i = 0; i < op_run_info->input_value.size(); ++i) {
+      if (enable_tuple_to_tensor_ && convert_tuple_and_scalar_into_tensor(i, op_run_info->input_value[i])) {
+        continue;
+      }
+      input_args[i] = op_run_info->input_value[i];
     }
   }
   if (op_run_info->base_op_run_info.has_dynamic_output) {
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.h b/mindspore/ccsrc/pipeline/pynative/grad/grad.h
index c3e49754253..0442484a077 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/grad.h
+++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.h
@@ -185,6 +185,7 @@ class GradExecutor {
   bool grad_is_running_{false};
   bool need_renormalize_{false};
   bool eliminate_forward_{true};
+  bool enable_tuple_to_tensor_{false};
   int custom_bprop_cell_count_{0};
   size_t cell_order_{0};
   size_t grad_order_{0};
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_cache.h b/mindspore/ccsrc/pipeline/pynative/pynative_cache.h
index 43904eedc56..407856bcc27 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_cache.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_cache.h
@@ -28,6 +28,103 @@
 namespace mindspore {
 namespace pynative {
+// This map records which inputs of an operator need to be converted to tensors.
+// The converted tensors are passed to the grad graph to prevent pass folding optimization.
+static std::unordered_map<std::string, std::vector<size_t>> kDynamicInputOpMap = {
+  {"AvgPool3DGrad", {0}},
+  {"Bernoulli", {1}},
+  {"ConjugateTranspose", {1}},
+  {"Conv2dBackpropFilter", {2}},
+  {"Conv2DTranspose", {2}},
+  {"Conv2dBackpropInput", {2}},
+  {"Conv3DBackpropFilter", {2}},
+  {"Conv3DBackpropInput", {2}},
+  {"CropAndResize", {3}},
+  {"CTCLossV2", {2, 3}},
+  {"CTCLossV2Grad", {2, 3}},
+  {"Cumprod", {1}},
+  {"CumSum", {1}},
+  {"EmbeddingLookup", {2}},
+  {"ExpandDims", {1}},
+  {"Fill", {1, 2}},
+  {"Fills", {1}},
+  {"Gather", {2}},
+  {"GatherD", {1}},
+  {"Greater", {0, 1}},
+  {"GreaterEqual", {0, 1}},
+  {"IndexFill", {1}},
+  {"InvertPermutation", {0}},
+  {"Lerp", {2}},
+  {"Less", {0, 1}},
+  {"LessEqual", {0, 1}},
+  {"LinSpace", {2}},
+  {"MaskedFill", {2}},
+  {"Multinomial", {1}},
+  {"NthElement", {1}},
+  {"OneHot", {1}},
+  {"PadV3Grad", {0}},
+  {"Padding", {1}},
+  {"ParallelConcat", {0}},
+  {"RandomCategorical", {1, 2}},
+  {"Poisson", {0}},
+  {"ReduceAll", {1}},
+  {"ReduceAny", {1}},
+  {"ReduceMax", {1}},
+  {"ReduceMean", {1}},
+  {"ReduceMin", {1}},
+  {"ReduceProd", {1}},
+  {"ReduceSum", {1}},
+  {"Reshape", {1}},
+  {"ResizeBilinearV2", {1}},
+  {"ScatterNd", {2}},
+  {"Slice", {1, 2}},
+  {"SliceGrad", {1, 2}},
+  {"StandardNormal", {0}},
+  {"Tile", {1}},
+  {"TopK", {1}},
+  {"Transpose", {1}},
+  {"TruncateDiv", {0, 1}},
+  {"TruncateMod", {0, 1}},
+  {"UniformInt", {0}},
+  {"UniformReal", {0}},
+  {"UnsortedSegmentMax", {2}},
+  {"UnsortedSegmentMin", {2}},
+  {"UnsortedSegmentProd", {2}},
+  {"UnsortedSegmentSum", {2}},
+  {"Xdivy", {0, 1}},
+  {"Xlogy", {0, 1}},
+  {"ScalarToTensor", {0}},
+  {"ScalarToArray", {0}},
+  {"StandardLaplace", {0}},
+  {"UniqueWithPad", {1}},
+  {"ApplyAdadelta", {3, 4, 5}},
+  {"ApplyAdagrad", {2}},
+  {"ApplyAdagradV2", {2}},
+  {"ApplyAdaMax", {3, 4, 5, 6, 7}},
+  {"ApplyAdamWithAmsgrad", {4, 5, 6}},
+  {"ApplyAddSign", {2, 3, 4, 5}},
+  {"ApplyCenteredRmsProp", {6, 7, 8}},
+  {"ApplyFtrl", {4, 5, 6}},
+  {"ApplyGradientDescent", {1}},
+  {"ApplyKerasMomentum", {2, 4}},
+  {"ApplyMomentum", {2, 4}},
+  {"ApplyPowerSign", {2, 3, 4, 5}},
+  {"ApplyProximalAdagrad", {2, 3, 4}},
+  {"ApplyProximalGradientDescent", {1, 2, 3}},
+  {"ApplyRmsProp", {5, 6, 7}},
+  {"SparseApplyAdadelta", {3, 4}},
+  {"SparseApplyAdagradDA", {5, 6, 7}},
+  {"SparseApplyCenteredRMSProp", {4, 5, 6, 7}},
+  {"SparseApplyMomentum", {2, 5}},
+  {"SparseApplyProximalAdagrad", {2, 3, 4}},
+  {"SparseApplyProximalGradientDescent", {1, 2, 3}},
+  {"SparseApplyRMSProp", {3}},
+  {"SparseTensorDenseAdd", {2}},
+  {"SparseTensorDenseMatMul", {2}},
+  {"SparseToDense", {3}},
+  {"StridedSlice", {2, 3, 4}},
+  {"StridedSliceGrad", {2, 3, 4, 5}}};
+
 // The following structures used to get output abstract of op from cache
 struct AbsCacheKey {
   std::string prim_name_;
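
Note: below is a minimal, standalone sketch (not MindSpore code) of the lookup that the patched DoOpGrad performs before converting an input: an input is converted only when the op name appears in kDynamicInputOpMap and the input index is one of the marked positions. The map contents and the helper name NeedsTensorConversion are assumptions for illustration only.

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Assumed small subset of kDynamicInputOpMap, for illustration only.
static std::unordered_map<std::string, std::vector<size_t>> kDynamicInputOpMap = {
    {"ReduceSum", {1}}, {"StridedSlice", {2, 3, 4}}};

// Returns true when the given input index of the op is marked for tuple/scalar-to-tensor conversion.
bool NeedsTensorConversion(const std::string &op_name, size_t input_idx) {
  auto it = kDynamicInputOpMap.find(op_name);
  if (it == kDynamicInputOpMap.end()) {
    return false;
  }
  const auto &marked = it->second;
  return std::any_of(marked.begin(), marked.end(), [input_idx](size_t i) { return i == input_idx; });
}

int main() {
  std::cout << std::boolalpha;
  std::cout << NeedsTensorConversion("ReduceSum", 1) << "\n";   // true: the axis input is marked
  std::cout << NeedsTensorConversion("ReduceSum", 0) << "\n";   // false: the data input is not marked
  std::cout << NeedsTensorConversion("MatMul", 0) << "\n";      // false: op not in the map
  return 0;
}

In the patch itself, this check additionally requires top_cell()->dynamic_shape() and enable_tuple_to_tensor_ to be true before opt::CreateTupleTensor or ScalarToTensor is applied to the marked input.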