!49710 Add dynamic shape support for some ops' grad expanders.

Merge pull request !49710 from TronZhang/dyn_bprop_adapt_230301_2
i-robot 2023-03-04 03:49:54 +00:00 committed by Gitee
commit c1ef6e0390
6 changed files with 234 additions and 69 deletions

View File

@@ -193,46 +193,27 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
return Emit(prim::kZerosLike, {Tensor(0)});
}
}
if (node->isa<Parameter>()) {
if (node->get()->abstract()->isa<abstract::AbstractNone>()) {
return Emit(prim::kZerosLike, {Tensor(0)});
}
if (node->get()->abstract()->isa<abstract::AbstractTensor>()) {
return Emit(prim::kZerosLike, {node});
}
if (node->get()->abstract()->isa<abstract::AbstractTuple>()) {
NodePtrList list;
auto abstract_tuple = node->get()->abstract()->cast<abstract::AbstractTuplePtr>();
for (auto &e : abstract_tuple->elements()) {
if (e->isa<abstract::AbstractTensor>()) {
auto shape = e->BuildShape()->cast<abstract::ShapePtr>()->shape();
auto type = e->BuildType()->cast<TensorTypePtr>()->element();
list.emplace_back(Emit("Zeros", {EmitValue(MakeValue(shape)), EmitValue(type)}));
} else if (e->isa<abstract::AbstractScalar>()) {
list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
} else {
MS_LOG(WARNING) << "ZerosLike got UNKNOWN TYPE: " << e->ToString();
list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
}
}
return MakeTuple(list);
}
auto v = node->get()->abstract()->BuildValue();
if (v->isa<Scalar>() || v->isa<Type>()) {
return Emit(prim::kZerosLike, {Tensor(0, v->type())});
}
if (v->isa<ValueSequence>()) {
auto sh = GetValue<std::vector<int64_t>>(v);
return Emit(prim::kZerosLike, {Tensor(sh)});
auto abs = node->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTensor>()) {
return Emit(prim::kZerosLike, {node});
} else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractType>() ||
abs->isa<abstract::AbstractNone>()) {
// Nothing to zero for Monad/Type/None; return the node itself so that later CNodes can still link to it.
return node;
} else if (abs->isa<abstract::AbstractSequence>()) {
auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
if (sequence_abs->empty()) {
return node;
}
return Emit(prim::kSequenceZerosLike, {node});
} else if (abs->isa<abstract::AbstractScalar>()) {
return Emit(prim::kZerosLike, {Tensor(0, abs->BuildType())});
}
if (node->get()->abstract()->isa<abstract::AbstractMonad>()) {
return Emit(prim::kZerosLike, {Tensor(0)});
}
if (node->get()->abstract()->isa<abstract::AbstractNone>()) {
return Emit(prim::kZerosLike, {Tensor(0)});
}
return Emit(prim::kZerosLike, {node});
MS_LOG(EXCEPTION) << "Cannot emit ZerosLike for " << node->get()->ToString() << " with abstract " << abs;
}
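// [Editor's note] Summary of the rewritten dispatch above — ZerosLike now keys off the
// node's abstract kind rather than a concretely built value, which is what makes it safe
// under dynamic shapes:
//   AbstractTensor            -> Emit(kZerosLike, {node})          (shape resolved at runtime)
//   AbstractMonad/Type/None   -> the node itself (nothing to zero; kept for CNode linking)
//   AbstractSequence, empty   -> the node itself
//   AbstractSequence, n > 0   -> Emit(kSequenceZerosLike, {node})  (handles dynamic-length tuples)
//   AbstractScalar            -> Emit(kZerosLike, {Tensor(0, dtype)})
//   anything else             -> MS_LOG(EXCEPTION)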
NodePtr Emitter::Fill(double value, const ShapeVector &shape, TypeId data_type) const {
@@ -330,8 +311,8 @@ NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_
}
NodePtrList Emitter::ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &shape_func,
const ops::InferFunc &infer_func, const std::vector<int64_t> &value_depend_indices,
size_t size) const {
const ops::InferFunc &infer_func,
const std::vector<int64_t> &value_depend_indices) const {
MS_EXCEPTION_IF_NULL(shape_func);
MS_EXCEPTION_IF_NULL(infer_func);
if (inputs.empty()) {
@@ -384,8 +365,13 @@ NodePtrList Emitter::ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &
{ops::kAttrValueDependIndices, MakeValue(value_depend_indices)},
{kAttrPrimitiveTarget, MakeValue("CPU")}});
MS_EXCEPTION_IF_NULL(out);
if (size > 1) {
for (size_t i = 0; i < size; ++i) {
MS_EXCEPTION_IF_NULL(out->get());
auto abs = out->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTuple>()) {
auto abstract_tuple = abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
for (size_t i = 0; i < abstract_tuple->size(); ++i) {
res.push_back(TupleGetItem(out, i));
}
} else {

View File

@@ -173,7 +173,7 @@ class COMMON_EXPORT Emitter {
/// \return NodePtrList, the list of output shapes.
NodePtrList ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &shape_func, const ops::InferFunc &infer_func,
const std::vector<int64_t> &value_depend_indices = {}, size_t size = 1) const;
const std::vector<int64_t> &value_depend_indices = {}) const;
using BlockFunc = std::function<NodePtrList(const Emitter *)>;
/// \brief Generate a conditional block.
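
With the explicit size argument gone, ShapeCalc unpacks as many outputs as the operator's inferred AbstractTuple reports, so the infer_func's result count effectively drives the unpacking. A minimal standalone sketch of the shape_func/infer_func contract as this diff uses it (the PassThrough* names are illustrative; the real aliases are ops::ShapeFunc and ops::InferFunc):

#include <cstdint>
#include <unordered_set>
#include <vector>

using ShapeVector = std::vector<int64_t>;
using ShapeArray = std::vector<ShapeVector>;

// Runs once the depended-on shapes/values are concrete; returns the real outputs.
ShapeArray PassThroughShapeFunc(const ShapeArray &inputs) {
  return {inputs.at(0)};  // e.g. forward input 0's shape unchanged
}

// Runs at graph-build time; reports only each output's length, or -1 when an
// input listed in invalid_indices makes that length unknowable.
ShapeVector PassThroughInferFunc(const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) {
  if (!invalid_indices.empty()) {
    return {-1};
  }
  return {static_cast<int64_t>(inputs.at(0).size())};
}

Every shape/infer pair added in this diff follows the same contract: the infer side must report exactly one length per shape vector that the shape side later produces.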

View File

@@ -325,6 +325,13 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Triu);
REGISTER_EXPANDER_BPROP_IMPL(ZerosLike);
REGISTER_EXPANDER_BPROP_IMPL(Concat);
REGISTER_EXPANDER_BPROP_IMPL(Reshape);
REGISTER_EXPANDER_BPROP_IMPL(ExpandDims);
REGISTER_EXPANDER_BPROP_IMPL(Gather);
REGISTER_EXPANDER_BPROP_IMPL(GatherV2);
REGISTER_EXPANDER_BPROP_IMPL(Tile);
REGISTER_EXPANDER_BPROP_IMPL(StridedSlice);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentSum);
}
void RegArrayBpropExpanderOps2() {}

View File

@@ -38,13 +38,44 @@ NodePtrList GatherDropNegatives(const BpropIRBuilder *ib, const NodePtr &params,
if (is_positive_param == nullptr) {
is_positive = ib->GreaterEqual(ids, ib->Tensor(0, ib->GetDtype(ids)));
auto broadcastable_shape = ib->GetShape(is_positive);
auto back_size = ib->GetShape(gathered).size() - ib->GetShape(is_positive).size();
for (size_t i = 0; i < back_size; ++i) {
broadcastable_shape.push_back(1);
}
is_positive = ib->Reshape(is_positive, broadcastable_shape);
auto gathered_shape = ib->GetShape(gathered);
is_positive = ib->LogicalAnd(is_positive, ib->Fill(1.0, gathered_shape, TypeId::kNumberTypeBool));
if (IsDynamic(broadcastable_shape) || IsDynamic(gathered_shape)) {
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
auto is_pos = inputs.at(1);
auto gather_rank = inputs.at(0).size();
auto is_pos_rank = is_pos.size();
std::vector<int64_t> res_shape(is_pos.begin(), is_pos.end());
if (gather_rank > is_pos_rank) {
auto expand_len = gather_rank - is_pos_rank;
for (size_t i = 0; i < expand_len; ++i) {
res_shape.push_back(1);
}
}
return {res_shape};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
if (!invalid_indices.empty()) {
return {-1};
}
auto gather_rank = inputs.at(0).size();
auto is_pos_rank = inputs.at(1).size();
return {SizeToLong(std::max(gather_rank, is_pos_rank))};
};
auto is_positive_shape = ib->ShapeCalc({gathered, is_positive}, shape_func, infer_func, {})[0];
is_positive = ib->Reshape(is_positive, is_positive_shape);
auto shape_gather = ib->Emit("TupleToTensor", {ib->Emit("Shape", {gathered}), ib->Value(kInt64)});
is_positive = ib->LogicalAnd(is_positive, ib->Fill(1.0, shape_gather, TypeId::kNumberTypeBool));
} else {
auto back_size = ib->GetShape(gathered).size() - ib->GetShape(is_positive).size();
for (size_t i = 0; i < back_size; ++i) {
broadcastable_shape.push_back(1);
}
is_positive = ib->Reshape(is_positive, broadcastable_shape);
auto ones = ib->Fill(1.0, gathered_shape, TypeId::kNumberTypeBool);
is_positive = ib->LogicalAnd(is_positive, ones);
}
}
auto zero_slice = ib->ZerosLike(gathered);
return {ib->Select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive};
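// [Editor's note] Worked example of the shape_func above (hypothetical values): with
// gathered at rank 4 and is_positive of shape {5, 6}, res_shape comes out as {5, 6, 1, 1}.
// Right-padding with 1s keeps is_positive broadcastable against gathered, which is exactly
// what the static branch's push_back(1) loop does when both ranks are known at build time.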
@@ -117,19 +148,117 @@ NodePtrList TensorScatterPossibleReplacement(const BpropIRBuilder *ib) {
return {ib->Cast(dx, ib->GetDtype(x)), ib->ZerosLike(indices), ib->Cast(dupdates, ib->GetDtype(updates))};
}
ShapeArray RegenerateOutputShapeFunc(const ShapeArray &inputs) {
auto x_shape = inputs.at(0);
auto indices_shape = inputs.at(1);
auto axis = inputs.at(2);
MS_EXCEPTION_IF_CHECK_FAIL(axis.size() == 1, "axis should be a scalar.");
auto axis_value = axis[0];
if (axis_value < 0) {
axis_value += x_shape.size();
}
std::vector<int64_t> out_shape(x_shape.begin(), x_shape.begin() + axis_value);
out_shape.insert(out_shape.end(), indices_shape.begin(), indices_shape.end());
out_shape.insert(out_shape.end(), x_shape.begin() + axis_value + 1, x_shape.end());
return {out_shape};
}
ShapeVector RegenerateOutputInferFunc(const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) {
if (!invalid_indices.empty()) {
return {-1};
}
auto x_rank = inputs.at(0).size();
auto indices_rank = inputs.at(1).size();
return {static_cast<int64_t>(x_rank + indices_rank - 1)};
}
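// [Editor's note] Worked example of the two functions above (hypothetical values):
// x_shape {2, 3, 4}, indices_shape {5, 6}, axis {1} -> out_shape {2, 5, 6, 4}: x's axis
// dimension is replaced by the indices dimensions, so the output rank is
// x_rank + indices_rank - 1 (here 3 + 2 - 1 = 4), which is what the infer func must report.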
ShapeArray Perm1ShapeFunc(const ShapeArray &inputs) {
auto dout_rank = SizeToLong(inputs.at(0).size());
auto indices_rank = SizeToLong(inputs.at(1).size());
auto axis = inputs.at(2);
MS_EXCEPTION_IF_CHECK_FAIL(axis.size() == 1, "axis should be a scalar.");
auto axis_value = axis[0];
if (axis_value < 0) {
axis_value += dout_rank - indices_rank + 1;
}
std::vector<int64_t> perm(indices_rank);
std::iota(perm.begin(), perm.end(), axis_value);
int64_t index_end = std::min(dout_rank, axis_value);
for (int64_t i = 0; i < index_end; ++i) {
perm.push_back(i);
}
for (int64_t i = axis_value + indices_rank; i < dout_rank; ++i) {
perm.push_back(i);
}
return {perm};
}
ShapeVector PermInferFunc(const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) {
if (!invalid_indices.empty()) {
return {-1};
}
return {static_cast<int64_t>(inputs.at(0).size())};
}
ShapeArray Perm2ShapeFunc(const ShapeArray &inputs) {
auto x_rank = SizeToLong(inputs.at(0).size());
auto axis = inputs.at(1);
MS_EXCEPTION_IF_CHECK_FAIL(axis.size() == 1, "axis should be a scalar.");
auto axis_value = axis[0];
if (axis_value < 0) {
axis_value += x_rank;
}
std::vector<int64_t> perm(axis_value);
std::iota(perm.begin(), perm.end(), 1);
perm.push_back(0);
for (int64_t i = axis_value + 1; i < x_rank; ++i) {
perm.push_back(i);
}
return {perm};
}
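// [Editor's note] Worked example of the two permutations (hypothetical values): for dout
// rank 4, indices rank 2, axis {1}, Perm1ShapeFunc yields {1, 2, 0, 3}, moving the indices
// dimensions to the front so UnsortedSegmentSum can segment along dim 0. Perm2ShapeFunc
// with x rank 3, axis {1} yields {1, 0, 2}, moving the leading segment dimension back to
// the original axis position and restoring x's layout.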
NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto axis = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4);
auto orig_indices = indices;
auto x_shp = ib->GetShape(x);
auto out_shp = ib->GetShape(dout);
auto ind_shp = ib->GetShape(indices);
auto axis_v = CheckRange(GetIntValue(axis), SizeToLong(x_shp.size()));
if (out_shp.empty()) {
dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
}
if (IsDynamic(x_shp) || IsDynamic(ind_shp) || IsDynamic(out_shp)) {
if (ind_shp.empty()) {
indices = ib->Emit("ExpandDims", {indices, ib->Tensor(-1)});
auto out_shp1 = ib->ShapeCalc({x, indices, axis}, RegenerateOutputShapeFunc, RegenerateOutputInferFunc, {2})[0];
dout = ib->Reshape(dout, out_shp1);
}
// Calculate perm.
auto perm_1 = ib->ShapeCalc({dout, indices, axis}, Perm1ShapeFunc, PermInferFunc, {2})[0];
auto perm_2 = ib->ShapeCalc({x, axis}, Perm2ShapeFunc, PermInferFunc, {1})[0];
auto values_transpose = ib->Transpose(dout, perm_1);
auto x_grad = ib->Emit("UnsortedSegmentSum",
{values_transpose, indices, ib->Emit("TupleGetItem", {ib->Emit("Shape", {x}), axis})});
x_grad = ib->Transpose(x_grad, perm_2);
return {x_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
}
auto axis_v = CheckRange(GetIntValue(axis), SizeToLong(x_shp.size()));
if (ind_shp.empty()) {
indices = ib->Emit("ExpandDims", {indices, ib->Tensor(-1)});
ind_shp = ib->GetShape(indices);
@@ -142,7 +271,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
auto params_grad = ib->Transpose(tmp, perm_2);
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
}
int64_t RecorrectAxis(int64_t axis, size_t rank) {
@@ -433,8 +562,15 @@ REG_BPROP_BUILDER("StridedSlice").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) {
auto end = ib->GetInput(kIndex2);
auto strides = ib->GetInput(kIndex3);
auto dout = ib->GetInput(kIndex5);
auto x_shape = ib->EmitValue(MakeValue(ib->GetShape(x)));
auto dx = ib->Emit("StridedSliceGrad", {dout, x_shape, begin, end, strides},
auto x_shape_vec = ib->GetShape(x);
NodePtr x_shape_node;
if (IsDynamic(x_shape_vec)) {
x_shape_node = ib->Emit("TensorShape", {x});
} else {
x_shape_node = ib->EmitValue(MakeValue(x_shape_vec));
}
auto dx = ib->Emit("StridedSliceGrad", {dout, x_shape_node, begin, end, strides},
{{"begin_mask", ib->GetAttr("begin_mask")},
{"end_mask", ib->GetAttr("end_mask")},
{"ellipsis_mask", ib->GetAttr("ellipsis_mask")},
@@ -546,8 +682,14 @@ REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib)
auto x = ib->GetInput(kIndex0);
auto shp = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto shapex = ib->GetShape(x);
return {ib->Reshape(dout, shapex), ib->ZerosLike(shp)};
auto shape_x = ib->GetShape(x);
NodePtr dx;
if (!IsDynamic(shape_x)) {
dx = ib->Reshape(dout, shape_x);
} else {
dx = ib->Reshape(dout, ib->Emit("TensorShape", {x}));
}
return {dx, ib->ZerosLike(shp)};
});
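// [Editor's note] Standalone sketch of the IsDynamic guard these bprops branch on
// (assumed semantics: any negative dimension marks the compile-time shape as unusable):
//
//   bool IsDynamicSketch(const std::vector<int64_t> &shape) {
//     return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim < 0; });
//   }
//
// IsDynamicSketch({2, -1, 4}) -> true : Reshape must consume Emit("TensorShape", {x}) instead.
// IsDynamicSketch({2, 3, 4})  -> false: the constant shape can be baked into the graph.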
REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
@@ -667,11 +809,11 @@ REG_BPROP_BUILDER("Concat").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
for (size_t i = 0; i < input_nums; ++i) {
x_tuple.push_back(ib->TupleGetItem(x, i));
}
auto concat_offset = ib->ShapeCalc(x_tuple, shape_func, infer_func, {}, input_nums);
auto concat_offset = ib->ShapeCalc(x_tuple, shape_func, infer_func, {});
NodePtrList res;
for (size_t i = 0; i < input_nums; ++i) {
auto input = ib->Emit("Shape", {ib->TupleGetItem(x, i)});
auto slice_out = ib->Emit(kSliceOpName, {dout, concat_offset[i], input});
auto slice_out = ib->Emit(kSliceOpName, {dout, concat_offset.at(i), input});
res.push_back(slice_out);
}
@@ -912,8 +1054,14 @@ REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(i
auto x = ib->GetInput(kIndex0);
auto axis = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto shapex = ib->GetShape(x);
return {ib->Reshape(dout, shapex), ib->ZerosLike(axis)};
auto shape_x = ib->GetShape(x);
NodePtr dx;
if (IsDynamic(shape_x)) {
dx = ib->Reshape(dout, ib->Emit("TensorShape", {x}));
} else {
dx = ib->Reshape(dout, shape_x);
}
return {dx, ib->ZerosLike(axis)};
});
REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
@@ -963,10 +1111,33 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto input_multiples = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto shapex = ib->GetShape(x);
auto multiples = GetIntList(input_multiples);
auto r_shape = TileShape(multiples, shapex);
auto axis = Range(0, static_cast<int64_t>(r_shape.size()), 2);
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
// {x_shape, multiples}
auto r_shape = TileShape(inputs.at(1), inputs.at(0));
ShapeVector axis;
size_t axis_sz = r_shape.size() / 2;
axis.reserve(axis_sz);
for (int64_t i = 0; i < static_cast<int64_t>(axis_sz); ++i) {
axis.push_back(i * 2);
}
return {r_shape, axis};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
if (!invalid_indices.empty()) {
return {-1, -1};
}
auto x_sz = static_cast<int64_t>(inputs.at(0).size());
auto multiples_sz = static_cast<int64_t>(inputs.at(1).size());
auto max_sz = x_sz > multiples_sz ? x_sz : multiples_sz;
return {2 * max_sz, max_sz};
};
auto calc_res = ib->ShapeCalc({x, input_multiples}, shape_func, infer_func, {1});
auto r_shape = calc_res[0];
auto axis = calc_res[1];
auto dout_reshaped = ib->Reshape(dout, r_shape);
auto dout_dtype = ib->GetDtype(dout_reshaped)->type_id();
NodePtr dx;
@@ -974,15 +1145,16 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
if (need_reduce.first) {
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
dout_reshaped = ib->Cast(dout_reshaped, kFloat32);
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
dx = ib->Emit("ReduceSum", {dout_reshaped, axis}, {{"keep_dims", MakeValue(false)}});
dx = ib->Cast(dx, dout_dtype);
} else {
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
dx = ib->Emit("ReduceSum", {dout_reshaped, axis}, {{"keep_dims", MakeValue(false)}});
}
} else {
dx = ib->Reshape(dout_reshaped, need_reduce.second);
}
dx = ib->Reshape(dx, shapex);
auto shape_x = ib->Emit("Shape", {x});
dx = ib->Reshape(dx, shape_x);
return {dx, ib->ZerosLike(input_multiples)};
});
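
The Tile bprop now routes both the interleaved reshape target and the reduction axes through ShapeCalc so they survive unknown dimensions. A standalone sketch of the same arithmetic for fully known shapes, assuming TileShape's usual interleaved {multiple_i, dim_i} layout and equal ranks (all names and values here are illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x_shape{3, 4}, multiples{2, 5};
  std::vector<int64_t> r_shape, axis;
  for (size_t i = 0; i < x_shape.size(); ++i) {
    axis.push_back(static_cast<int64_t>(r_shape.size()));  // multiples sit at even positions: 0, 2, ...
    r_shape.push_back(multiples[i]);
    r_shape.push_back(x_shape[i]);
  }
  // dout of Tile(x, {2, 5}) has shape {6, 20}; reshaping it to r_shape {2, 3, 5, 4} and
  // reducing over axis {0, 2} recovers dx with x's shape {3, 4}.
  for (int64_t d : r_shape) std::cout << d << ' ';  // prints: 2 3 5 4
  std::cout << '\n';
  for (int64_t a : axis) std::cout << a << ' ';     // prints: 0 2
  std::cout << '\n';
  return 0;
}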

View File

@@ -90,8 +90,7 @@ bool ShapeCalcCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
}
std::vector<KernelAttr> ShapeCalcCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64)}};
static std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
return support_list;
}
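// [Editor's note] AddSkipCheckAttr(true) bypasses the kernel-attr matching that the old
// int64-only entry enforced, so the ShapeCalc CPU kernel is no longer restricted to an
// int64-in/int64-out signature; any validation it still needs happens inside Launch.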

View File

@@ -72,6 +72,7 @@ int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abs
auto value = input_args[kInputIndex2]->BuildValue();
if (!IsValueKnown(value)) {
num_segments_v = abstract::Shape::kShapeDimAny;
return num_segments_v;
}
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
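// [Editor's note] The added early return matters under dynamic shape: when num_segments
// is not a known value at build time, the function now reports
// abstract::Shape::kShapeDimAny (-1) immediately instead of falling through to the
// scalar-cast path below, which assumes a concrete value.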