!41403 Convert tuple and scalar into tensor for bprop in dynamic shape mode.

Merge pull request !41403 from wanghenchang/dynamic_shape_infer
This commit is contained in:
i-robot 2022-09-12 07:30:52 +00:00 committed by Gitee
commit ca371e1531
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 134 additions and 9 deletions

View File

@ -25,6 +25,7 @@
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/debug/trace.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/convert_utils_py.h"
#include "frontend/optimizer/ad/grad.h"
#include "pipeline/jit/pass.h"
@ -1448,11 +1449,11 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info, co
const auto &obj_id = PyNativeAlgo::Common::GetIdByValue(v);
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
SaveOutputNodeMap(obj_id, v, cnode);
DoOpGrad(op_run_info, cnode, v);
// Dynamic shape should update to top cell
if (PyNativeAlgo::Common::IsDynamicShape(op_run_info)) {
top_cell()->set_dynamic_shape(true);
}
DoOpGrad(op_run_info, cnode, v);
}
forward()->SetNodeAbsMapByValue(v, op_run_info->base_op_run_info.abstract);
UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, v);
@ -1495,15 +1496,41 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
if (op_run_info->run_in_vm) {
input_args = op_run_info->input_value;
} else {
// Run in Ms, some op input tensor convert into attributes, so add them back
for (auto &it : op_run_info->index_with_value) {
input_args[it.first] = it.second;
}
// Add other tensor
for (size_t i = 0; i < op_run_info->input_value.size(); ++i) {
if (input_args[i] == nullptr) {
input_args[i] = op_run_info->input_value[i];
auto convert_tuple_and_scalar_into_tensor = [&](size_t idx, const ValuePtr &default_value) -> bool {
if (top_cell()->dynamic_shape() &&
kDynamicInputOpMap.find(op_run_info->base_op_run_info.op_name) != kDynamicInputOpMap.end()) {
const auto &input_vec = kDynamicInputOpMap[op_run_info->base_op_run_info.op_name];
bool marked = std::any_of(input_vec.begin(), input_vec.end(), [&idx](size_t i) { return idx == i; });
if (marked) {
if (default_value->isa<ValueSequence>()) {
MS_LOG(DEBUG) << "Ready to convert tulpe into tensor, op name:" << op_run_info->base_op_run_info.op_name
<< ", index:" << idx;
ValueSequencePtr value_seq = default_value->cast<ValueSequencePtr>();
ValueTuplePtr value_tuple;
if (value_seq->isa<ValueList>()) {
value_tuple = std::make_shared<ValueTuple>(value_seq->value());
} else {
value_tuple = value_seq->cast<ValueTuplePtr>();
}
auto tensor_ptr = opt::CreateTupleTensor(value_tuple);
input_args[idx] = tensor_ptr;
return true;
} else if (default_value->isa<Scalar>()) {
MS_LOG(DEBUG) << "Ready to convert scalar into tensor, op name:" << op_run_info->base_op_run_info.op_name
<< ", index:" << idx;
auto scalar_tensor = ScalarToTensor(default_value->cast<ScalarPtr>());
input_args[idx] = scalar_tensor;
return true;
}
}
}
return false;
};
for (size_t i = 0; i < op_run_info->input_value.size(); ++i) {
if (enable_tuple_to_tensor_ && convert_tuple_and_scalar_into_tensor(i, op_run_info->input_value[i])) {
continue;
}
input_args[i] = op_run_info->input_value[i];
}
}
if (op_run_info->base_op_run_info.has_dynamic_output) {

View File

@ -185,6 +185,7 @@ class GradExecutor {
bool grad_is_running_{false};
bool need_renormalize_{false};
bool eliminate_forward_{true};
bool enable_tuple_to_tensor_{false};
int custom_bprop_cell_count_{0};
size_t cell_order_{0};
size_t grad_order_{0};

View File

@ -28,6 +28,103 @@
namespace mindspore {
namespace pynative {
// This map is used to record which input of operator needs to convert to tensor.
// This tensor is passed to grad graph to prevent pass foding optimization.
static std::unordered_map<std::string, std::vector<size_t>> kDynamicInputOpMap = {
{"AvgPool3DGrad", {0}},
{"Bernoulli", {1}},
{"ConjugateTranspose", {1}},
{"Conv2dBackpropFilter", {2}},
{"Conv2DTranspose", {2}},
{"Conv2dBackpropInput", {2}},
{"Conv3DBackpropFilter", {2}},
{"Conv3DBackpropInput", {2}},
{"CropAndResize", {3}},
{"CTCLossV2", {2, 3}},
{"CTCLossV2Grad", {2, 3}},
{"Cumprod", {1}},
{"CumSum", {1}},
{"EmbeddingLookup", {2}},
{"ExpandDims", {1}},
{"Fill", {1, 2}},
{"Fills", {1}},
{"Gather", {2}},
{"GatherD", {1}},
{"Greater", {0, 1}},
{"GreaterEqual", {0, 1}},
{"IndexFill", {1}},
{"InvertPermutation", {0}},
{"Lerp", {2}},
{"Less", {0, 1}},
{"LessEqual", {0, 1}},
{"LinSpace", {2}},
{"MaskedFill", {2}},
{"Multinomial", {1}},
{"NthElement", {1}},
{"OneHot", {1}},
{"PadV3Grad", {0}},
{"Padding", {1}},
{"ParallelConcat", {0}},
{"RandomCategorical", {1, 2}},
{"Poisson", {0}},
{"ReduceAll", {1}},
{"ReduceAny", {1}},
{"ReduceMax", {1}},
{"ReduceMean", {1}},
{"ReduceMin", {1}},
{"ReduceProd", {1}},
{"ReduceSum", {1}},
{"Reshape", {1}},
{"ResizeBilinearV2", {1}},
{"ScatterNd", {2}},
{"Slice", {1, 2}},
{"SliceGrad", {1, 2}},
{"StandardNormal", {0}},
{"Tile", {1}},
{"TopK", {1}},
{"Transpose", {1}},
{"TruncateDiv", {0, 1}},
{"TruncateMod", {0, 1}},
{"UniformInt", {0}},
{"UniformReal", {0}},
{"UnsortedSegmentMax", {2}},
{"UnsortedSegmentMin", {2}},
{"UnsortedSegmentProd", {2}},
{"UnsortedSegmentSum", {2}},
{"Xdivy", {0, 1}},
{"Xlogy", {0, 1}},
{"ScalarToTensor", {0}},
{"ScalarToArray", {0}},
{"StandardLaplace", {0}},
{"UniqueWithPad", {1}},
{"ApplyAdadelta", {3, 4, 5}},
{"ApplyAdagrad", {2}},
{"ApplyAdagradV2", {2}},
{"ApplyAdaMax", {3, 4, 5, 6, 7}},
{"ApplyAdamWithAmsgrad", {4, 5, 6}},
{"ApplyAddSign", {2, 3, 4, 5}},
{"ApplyCenteredRmsProp", {6, 7, 8}},
{"ApplyFtrl", {4, 5, 6}},
{"ApplyGradientDescent", {1}},
{"ApplyKerasMomentum", {2, 4}},
{"ApplyMomentum", {2, 4}},
{"ApplyPowerSign", {2, 3, 4, 5}},
{"ApplyProximalAdagrad", {2, 3, 4}},
{"ApplyProximalGradientDescent", {1, 2, 3}},
{"ApplyRmsProp", {5, 6, 7}},
{"SparseApplyAdadelta", {3, 4}},
{"SparseApplyAdagradDA", {5, 6, 7}},
{"SparseApplyCenteredRMSProp", {4, 5, 6, 7}},
{"SparseApplyMomentum", {2, 5}},
{"SparseApplyProximalAdagrad", {2, 3, 4}},
{"SparseApplyProximalGradientDescent", {1, 2, 3}},
{"SparseApplyRMSProp", {3}},
{"SparseTensorDenseAdd", {2}},
{"SparseTensorDenseMatMul", {2}},
{"SparseToDense", {3}},
{"StridedSlice", {2, 3, 4}},
{"StridedSliceGrad", {2, 3, 4, 5}}};
// The following structures used to get output abstract of op from cache
struct AbsCacheKey {
std::string prim_name_;