!41403 Convert tuple and scalar into tensor for bprop in dynamic shape mode.
Merge pull request !41403 from wanghenchang/dynamic_shape_infer

commit ca371e1531

@@ -25,6 +25,7 @@
 #include "pipeline/jit/parse/data_converter.h"
 #include "pipeline/jit/debug/trace.h"
 #include "frontend/optimizer/ad/prim_bprop_optimizer.h"
+#include "backend/common/optimizer/helper.h"
 #include "include/common/utils/convert_utils_py.h"
 #include "frontend/optimizer/ad/grad.h"
 #include "pipeline/jit/pass.h"
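
The new include, backend/common/optimizer/helper.h, is presumably what supplies opt::CreateTupleTensor, which the DoOpGrad change below calls.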

@@ -1448,11 +1449,11 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info, co
     const auto &obj_id = PyNativeAlgo::Common::GetIdByValue(v);
     cnode->set_abstract(op_run_info->base_op_run_info.abstract);
     SaveOutputNodeMap(obj_id, v, cnode);
-    DoOpGrad(op_run_info, cnode, v);
     // Dynamic shape should be updated to the top cell
     if (PyNativeAlgo::Common::IsDynamicShape(op_run_info)) {
       top_cell()->set_dynamic_shape(true);
     }
+    DoOpGrad(op_run_info, cnode, v);
   }
   forward()->SetNodeAbsMapByValue(v, op_run_info->base_op_run_info.abstract);
   UpdateForwardTensorInfoInBpropGraph(op_run_info->op_info, v);
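
The reordering above is the functional change in this hunk: DoOpGrad now runs after the dynamic-shape flag is propagated to the top cell, so the conversion lambda added to DoOpGrad below (which reads top_cell()->dynamic_shape()) already sees the flag for the very op that first makes the cell dynamic.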

@@ -1495,15 +1496,41 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
   if (op_run_info->run_in_vm) {
     input_args = op_run_info->input_value;
   } else {
     // Run in Ms, some op input tensors are converted into attributes, so add them back
     for (auto &it : op_run_info->index_with_value) {
       input_args[it.first] = it.second;
     }
     // Add other tensors
-    for (size_t i = 0; i < op_run_info->input_value.size(); ++i) {
-      if (input_args[i] == nullptr) {
-        input_args[i] = op_run_info->input_value[i];
-      }
-    }
+    auto convert_tuple_and_scalar_into_tensor = [&](size_t idx, const ValuePtr &default_value) -> bool {
+      if (top_cell()->dynamic_shape() &&
+          kDynamicInputOpMap.find(op_run_info->base_op_run_info.op_name) != kDynamicInputOpMap.end()) {
+        const auto &input_vec = kDynamicInputOpMap[op_run_info->base_op_run_info.op_name];
+        bool marked = std::any_of(input_vec.begin(), input_vec.end(), [&idx](size_t i) { return idx == i; });
+        if (marked) {
+          if (default_value->isa<ValueSequence>()) {
+            MS_LOG(DEBUG) << "Ready to convert tuple into tensor, op name:" << op_run_info->base_op_run_info.op_name
+                          << ", index:" << idx;
+            ValueSequencePtr value_seq = default_value->cast<ValueSequencePtr>();
+            ValueTuplePtr value_tuple;
+            if (value_seq->isa<ValueList>()) {
+              value_tuple = std::make_shared<ValueTuple>(value_seq->value());
+            } else {
+              value_tuple = value_seq->cast<ValueTuplePtr>();
+            }
+            auto tensor_ptr = opt::CreateTupleTensor(value_tuple);
+            input_args[idx] = tensor_ptr;
+            return true;
+          } else if (default_value->isa<Scalar>()) {
+            MS_LOG(DEBUG) << "Ready to convert scalar into tensor, op name:" << op_run_info->base_op_run_info.op_name
+                          << ", index:" << idx;
+            auto scalar_tensor = ScalarToTensor(default_value->cast<ScalarPtr>());
+            input_args[idx] = scalar_tensor;
+            return true;
+          }
+        }
+      }
+      return false;
+    };
+    for (size_t i = 0; i < op_run_info->input_value.size(); ++i) {
+      if (enable_tuple_to_tensor_ && convert_tuple_and_scalar_into_tensor(i, op_run_info->input_value[i])) {
+        continue;
+      }
+      input_args[i] = op_run_info->input_value[i];
+    }
   }
   if (op_run_info->base_op_run_info.has_dynamic_output) {
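
For a marked input, the lambda above normalizes a ValueList into a ValueTuple before calling opt::CreateTupleTensor, and sends scalars through ScalarToTensor instead. A self-contained sketch of those two branches, with hypothetical stand-in types (Scalar/Sequence/Tensor) replacing the MindSpore Value hierarchy:

#include <iostream>
#include <variant>
#include <vector>

// Hypothetical stand-ins: a Value is either a scalar or a sequence (the
// analogue of Scalar vs. ValueList/ValueTuple); a Tensor is a flat buffer.
using Scalar = double;
using Sequence = std::vector<double>;
using Value = std::variant<Scalar, Sequence>;
struct Tensor {
  std::vector<double> data;
};

// Analogue of the lambda's two branches: a sequence (list or tuple alike,
// after normalization) becomes a 1-D tensor; a scalar becomes a
// single-element tensor, as ScalarToTensor would produce.
Tensor ConvertToTensor(const Value &v) {
  if (const auto *seq = std::get_if<Sequence>(&v)) {
    return Tensor{*seq};                  // tuple/list -> 1-D tensor
  }
  return Tensor{{std::get<Scalar>(v)}};   // scalar -> one-element tensor
}

int main() {
  Value axis = Sequence{0, 2};  // e.g. a tuple 'axis' input
  Value fill = Scalar{3.14};    // e.g. a scalar 'value' input
  std::cout << ConvertToTensor(axis).data.size() << "\n";  // 2
  std::cout << ConvertToTensor(fill).data.size() << "\n";  // 1
  return 0;
}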

@@ -185,6 +185,7 @@ class GradExecutor {
   bool grad_is_running_{false};
   bool need_renormalize_{false};
   bool eliminate_forward_{true};
+  bool enable_tuple_to_tensor_{false};
   int custom_bprop_cell_count_{0};
   size_t cell_order_{0};
   size_t grad_order_{0};

@@ -28,6 +28,103 @@
 
 namespace mindspore {
 namespace pynative {
+// This map records which inputs of an operator need to be converted to tensors.
+// These tensors are passed to the grad graph to prevent pass folding optimization.
+static std::unordered_map<std::string, std::vector<size_t>> kDynamicInputOpMap = {
+  {"AvgPool3DGrad", {0}},
+  {"Bernoulli", {1}},
+  {"ConjugateTranspose", {1}},
+  {"Conv2dBackpropFilter", {2}},
+  {"Conv2DTranspose", {2}},
+  {"Conv2dBackpropInput", {2}},
+  {"Conv3DBackpropFilter", {2}},
+  {"Conv3DBackpropInput", {2}},
+  {"CropAndResize", {3}},
+  {"CTCLossV2", {2, 3}},
+  {"CTCLossV2Grad", {2, 3}},
+  {"Cumprod", {1}},
+  {"CumSum", {1}},
+  {"EmbeddingLookup", {2}},
+  {"ExpandDims", {1}},
+  {"Fill", {1, 2}},
+  {"Fills", {1}},
+  {"Gather", {2}},
+  {"GatherD", {1}},
+  {"Greater", {0, 1}},
+  {"GreaterEqual", {0, 1}},
+  {"IndexFill", {1}},
+  {"InvertPermutation", {0}},
+  {"Lerp", {2}},
+  {"Less", {0, 1}},
+  {"LessEqual", {0, 1}},
+  {"LinSpace", {2}},
+  {"MaskedFill", {2}},
+  {"Multinomial", {1}},
+  {"NthElement", {1}},
+  {"OneHot", {1}},
+  {"PadV3Grad", {0}},
+  {"Padding", {1}},
+  {"ParallelConcat", {0}},
+  {"RandomCategorical", {1, 2}},
+  {"Poisson", {0}},
+  {"ReduceAll", {1}},
+  {"ReduceAny", {1}},
+  {"ReduceMax", {1}},
+  {"ReduceMean", {1}},
+  {"ReduceMin", {1}},
+  {"ReduceProd", {1}},
+  {"ReduceSum", {1}},
+  {"Reshape", {1}},
+  {"ResizeBilinearV2", {1}},
+  {"ScatterNd", {2}},
+  {"Slice", {1, 2}},
+  {"SliceGrad", {1, 2}},
+  {"StandardNormal", {0}},
+  {"Tile", {1}},
+  {"TopK", {1}},
+  {"Transpose", {1}},
+  {"TruncateDiv", {0, 1}},
+  {"TruncateMod", {0, 1}},
+  {"UniformInt", {0}},
+  {"UniformReal", {0}},
+  {"UnsortedSegmentMax", {2}},
+  {"UnsortedSegmentMin", {2}},
+  {"UnsortedSegmentProd", {2}},
+  {"UnsortedSegmentSum", {2}},
+  {"Xdivy", {0, 1}},
+  {"Xlogy", {0, 1}},
+  {"ScalarToTensor", {0}},
+  {"ScalarToArray", {0}},
+  {"StandardLaplace", {0}},
+  {"UniqueWithPad", {1}},
+  {"ApplyAdadelta", {3, 4, 5}},
+  {"ApplyAdagrad", {2}},
+  {"ApplyAdagradV2", {2}},
+  {"ApplyAdaMax", {3, 4, 5, 6, 7}},
+  {"ApplyAdamWithAmsgrad", {4, 5, 6}},
+  {"ApplyAddSign", {2, 3, 4, 5}},
+  {"ApplyCenteredRmsProp", {6, 7, 8}},
+  {"ApplyFtrl", {4, 5, 6}},
+  {"ApplyGradientDescent", {1}},
+  {"ApplyKerasMomentum", {2, 4}},
+  {"ApplyMomentum", {2, 4}},
+  {"ApplyPowerSign", {2, 3, 4, 5}},
+  {"ApplyProximalAdagrad", {2, 3, 4}},
+  {"ApplyProximalGradientDescent", {1, 2, 3}},
+  {"ApplyRmsProp", {5, 6, 7}},
+  {"SparseApplyAdadelta", {3, 4}},
+  {"SparseApplyAdagradDA", {5, 6, 7}},
+  {"SparseApplyCenteredRMSProp", {4, 5, 6, 7}},
+  {"SparseApplyMomentum", {2, 5}},
+  {"SparseApplyProximalAdagrad", {2, 3, 4}},
+  {"SparseApplyProximalGradientDescent", {1, 2, 3}},
+  {"SparseApplyRMSProp", {3}},
+  {"SparseTensorDenseAdd", {2}},
+  {"SparseTensorDenseMatMul", {2}},
+  {"SparseToDense", {3}},
+  {"StridedSlice", {2, 3, 4}},
+  {"StridedSliceGrad", {2, 3, 4, 5}}};
+
 // The following structures are used to get the output abstract of an op from the cache
 struct AbsCacheKey {
   std::string prim_name_;
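
Each map entry pairs an op name with the zero-based indices of the inputs whose tuple or scalar values get materialized as tensors; {"ReduceSum", {1}}, for example, marks the axis input while leaving the data tensor at index 0 alone. A minimal stand-alone sketch of the lookup that convert_tuple_and_scalar_into_tensor performs against this map (trimmed map, plain std types, hypothetical NeedConvert helper):

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Trimmed copy of kDynamicInputOpMap: op name -> indices of inputs to convert.
static std::unordered_map<std::string, std::vector<size_t>> kDynamicInputOpMap = {
    {"ReduceSum", {1}}, {"Tile", {1}}, {"StridedSlice", {2, 3, 4}}};

// Mirrors the check in convert_tuple_and_scalar_into_tensor: the op must be
// listed, and the given input index must be marked for conversion.
bool NeedConvert(const std::string &op_name, size_t idx) {
  auto it = kDynamicInputOpMap.find(op_name);
  if (it == kDynamicInputOpMap.end()) {
    return false;
  }
  const auto &input_vec = it->second;
  return std::any_of(input_vec.begin(), input_vec.end(), [idx](size_t i) { return i == idx; });
}

int main() {
  std::cout << NeedConvert("ReduceSum", 1) << "\n";  // 1: axis input is marked
  std::cout << NeedConvert("ReduceSum", 0) << "\n";  // 0: data input stays as-is
  std::cout << NeedConvert("MatMul", 0) << "\n";     // 0: op not in the map
  return 0;
}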