forked from mindspore-Ecosystem/mindspore
!49710 Add dynamic support for some ops' grad expanders.
Merge pull request !49710 from TronZhang/dyn_bprop_adapt_230301_2
commit c1ef6e0390
@@ -193,46 +193,27 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
       return Emit(prim::kZerosLike, {Tensor(0)});
     }
   }
-  if (node->isa<Parameter>()) {
-    if (node->get()->abstract()->isa<abstract::AbstractNone>()) {
-      return Emit(prim::kZerosLike, {Tensor(0)});
-    }
-    if (node->get()->abstract()->isa<abstract::AbstractTensor>()) {
-      return Emit(prim::kZerosLike, {node});
-    }
-    if (node->get()->abstract()->isa<abstract::AbstractTuple>()) {
-      NodePtrList list;
-      auto abstract_tuple = node->get()->abstract()->cast<abstract::AbstractTuplePtr>();
-      for (auto &e : abstract_tuple->elements()) {
-        if (e->isa<abstract::AbstractTensor>()) {
-          auto shape = e->BuildShape()->cast<abstract::ShapePtr>()->shape();
-          auto type = e->BuildType()->cast<TensorTypePtr>()->element();
-          list.emplace_back(Emit("Zeros", {EmitValue(MakeValue(shape)), EmitValue(type)}));
-        } else if (e->isa<abstract::AbstractScalar>()) {
-          list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
-        } else {
-          MS_LOG(WARNING) << "ZerosLike got UNKNOWN TYPE: " << e->ToString();
-          list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
-        }
-      }
-      return MakeTuple(list);
-    }
-    auto v = node->get()->abstract()->BuildValue();
-    if (v->isa<Scalar>() || v->isa<Type>()) {
-      return Emit(prim::kZerosLike, {Tensor(0, v->type())});
-    }
-    if (v->isa<ValueSequence>()) {
-      auto sh = GetValue<std::vector<int64_t>>(v);
-      return Emit(prim::kZerosLike, {Tensor(sh)});
-    }
-  }
-  if (node->get()->abstract()->isa<abstract::AbstractMonad>()) {
-    return Emit(prim::kZerosLike, {Tensor(0)});
-  }
-  if (node->get()->abstract()->isa<abstract::AbstractNone>()) {
-    return Emit(prim::kZerosLike, {Tensor(0)});
-  }
-  return Emit(prim::kZerosLike, {node});
+  auto abs = node->abstract();
+  MS_EXCEPTION_IF_NULL(abs);
+
+  if (abs->isa<abstract::AbstractTensor>()) {
+    return Emit(prim::kZerosLike, {node});
+  } else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractType>() ||
+             abs->isa<abstract::AbstractNone>()) {
+    // Do not return the input directly; return an operator so later cnodes can link to it.
+    return node;
+  } else if (abs->isa<abstract::AbstractSequence>()) {
+    auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
+    if (sequence_abs->empty()) {
+      return node;
+    }
+    return Emit(prim::kSequenceZerosLike, {node});
+  } else if (abs->isa<abstract::AbstractScalar>()) {
+    return Emit(prim::kZerosLike, {Tensor(0, abs->BuildType())});
+  }
+
+  MS_LOG(EXCEPTION) << "Cannot emit ZerosLike for " << node->get()->ToString() << " with abstract " << abs;
 }

 NodePtr Emitter::Fill(double value, const ShapeVector &shape, TypeId data_type) const {
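The rewrite collapses the old Parameter/value special cases into a single dispatch on the node's abstract. A minimal standalone model of the new branch order (plain C++, illustration only; this is not MindSpore code, just a stand-in for the dispatch above):

#include <iostream>
#include <string>

// Stand-ins for the abstract kinds checked in the new code.
enum class Abs { kTensor, kMonad, kType, kNone, kEmptySequence, kSequence, kScalar };

std::string ZerosLikeFor(Abs abs) {
  switch (abs) {
    case Abs::kTensor:        return "ZerosLike(node)";
    case Abs::kMonad:                          // fall through
    case Abs::kType:                           // fall through
    case Abs::kNone:          return "node itself (kept so later cnodes can link)";
    case Abs::kEmptySequence: return "node itself";
    case Abs::kSequence:      return "SequenceZerosLike(node)";
    case Abs::kScalar:        return "ZerosLike(Tensor(0, scalar dtype))";
  }
  return "MS_LOG(EXCEPTION)";  // anything else is now a hard error
}

int main() {
  std::cout << ZerosLikeFor(Abs::kSequence) << '\n';  // prints: SequenceZerosLike(node)
  return 0;
}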
@@ -330,8 +311,8 @@ NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_
 }

 NodePtrList Emitter::ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &shape_func,
-                               const ops::InferFunc &infer_func, const std::vector<int64_t> &value_depend_indices,
-                               size_t size) const {
+                               const ops::InferFunc &infer_func,
+                               const std::vector<int64_t> &value_depend_indices) const {
   MS_EXCEPTION_IF_NULL(shape_func);
   MS_EXCEPTION_IF_NULL(infer_func);
   if (inputs.empty()) {
@@ -384,8 +365,13 @@ NodePtrList Emitter::ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &
                         {ops::kAttrValueDependIndices, MakeValue(value_depend_indices)},
                         {kAttrPrimitiveTarget, MakeValue("CPU")}});
   MS_EXCEPTION_IF_NULL(out);
-  if (size > 1) {
-    for (size_t i = 0; i < size; ++i) {
+  MS_EXCEPTION_IF_NULL(out->get());
+  auto abs = out->abstract();
+  MS_EXCEPTION_IF_NULL(abs);
+  if (abs->isa<abstract::AbstractTuple>()) {
+    auto abstract_tuple = abs->cast<abstract::AbstractTuplePtr>();
+    MS_EXCEPTION_IF_NULL(abstract_tuple);
+    for (size_t i = 0; i < abstract_tuple->size(); ++i) {
       res.push_back(TupleGetItem(out, i));
     }
   } else {
@@ -173,7 +173,7 @@ class COMMON_EXPORT Emitter {
   /// \param[in] size The size of outputs.
   /// \return NodePtrList, the outputs shape list.
   NodePtrList ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &shape_func, const ops::InferFunc &infer_func,
-                        const std::vector<int64_t> &value_depend_indices = {}, size_t size = 1) const;
+                        const std::vector<int64_t> &value_depend_indices = {}) const;

   using BlockFunc = std::function<NodePtrList(const Emitter *)>;
   /// \brief Generate a conditional block.
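For orientation, a sketch of the calling convention after this signature change (hypothetical shape_func/infer_func mirroring the real ones added later in this diff; `ib` is a builder as in the expander hunks below): the output count is no longer passed in as `size` but read off the ShapeCalc op's inferred tuple abstract.

// Sketch only, not compilable on its own.
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
  return {inputs.at(0)};  // computed on CPU at runtime for dynamic shapes
};
auto infer_func = [](const ShapeArray &inputs,
                     const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
  // Rank of each output tensor; -1 when an input rank is still unknown.
  return invalid_indices.empty() ? ShapeVector{SizeToLong(inputs.at(0).size())} : ShapeVector{-1};
};
auto outs = ib->ShapeCalc({x, axis}, shape_func, infer_func, /*value_depend_indices=*/{1});
auto out0 = outs[0];  // outs.size() follows the op's output abstract, not a size argument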
@@ -325,6 +325,13 @@ void RegArrayBpropExpanderOps1() {
   REGISTER_EXPANDER_BPROP_IMPL(Triu);
   REGISTER_EXPANDER_BPROP_IMPL(ZerosLike);
+  REGISTER_EXPANDER_BPROP_IMPL(Concat);
+  REGISTER_EXPANDER_BPROP_IMPL(Reshape);
+  REGISTER_EXPANDER_BPROP_IMPL(ExpandDims);
+  REGISTER_EXPANDER_BPROP_IMPL(Gather);
+  REGISTER_EXPANDER_BPROP_IMPL(GatherV2);
+  REGISTER_EXPANDER_BPROP_IMPL(Tile);
+  REGISTER_EXPANDER_BPROP_IMPL(StridedSlice);
   REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentSum);
 }

 void RegArrayBpropExpanderOps2() {}
@@ -38,13 +38,44 @@ NodePtrList GatherDropNegatives(const BpropIRBuilder *ib, const NodePtr &params,
   if (is_positive_param == nullptr) {
     is_positive = ib->GreaterEqual(ids, ib->Tensor(0, ib->GetDtype(ids)));
     auto broadcastable_shape = ib->GetShape(is_positive);
-    auto back_size = ib->GetShape(gathered).size() - ib->GetShape(is_positive).size();
-    for (size_t i = 0; i < back_size; ++i) {
-      broadcastable_shape.push_back(1);
-    }
-    is_positive = ib->Reshape(is_positive, broadcastable_shape);
-    auto gathered_shape = ib->GetShape(gathered);
-    is_positive = ib->LogicalAnd(is_positive, ib->Fill(1.0, gathered_shape, TypeId::kNumberTypeBool));
+    auto gathered_shape = ib->GetShape(gathered);
+    if (IsDynamic(broadcastable_shape) || IsDynamic(gathered_shape)) {
+      auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
+        auto is_pos = inputs.at(1);
+        auto gather_rank = inputs.at(0).size();
+        auto is_pos_rank = is_pos.size();
+
+        std::vector<int64_t> res_shape(is_pos.begin(), is_pos.end());
+        if (gather_rank > is_pos_rank) {
+          auto expand_len = gather_rank - is_pos_rank;
+          for (size_t i = 0; i < expand_len; ++i) {
+            res_shape.push_back(1);
+          }
+        }
+        return {res_shape};
+      };
+      auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
+        if (!invalid_indices.empty()) {
+          return {-1};
+        }
+        auto gather_rank = inputs.at(0).size();
+        auto is_pos_rank = inputs.at(1).size();
+        return {SizeToLong(std::max(gather_rank, is_pos_rank))};
+      };
+
+      auto is_positive_shape = ib->ShapeCalc({gathered, is_positive}, shape_func, infer_func, {})[0];
+      is_positive = ib->Reshape(is_positive, is_positive_shape);
+      auto shape_gather = ib->Emit("TupleToTensor", {ib->Emit("Shape", {gathered}), ib->Value(kInt64)});
+      is_positive = ib->LogicalAnd(is_positive, ib->Fill(1.0, shape_gather, TypeId::kNumberTypeBool));
+    } else {
+      auto back_size = ib->GetShape(gathered).size() - ib->GetShape(is_positive).size();
+      for (size_t i = 0; i < back_size; ++i) {
+        broadcastable_shape.push_back(1);
+      }
+      is_positive = ib->Reshape(is_positive, broadcastable_shape);
+      auto ones = ib->Fill(1.0, gathered_shape, TypeId::kNumberTypeBool);
+      is_positive = ib->LogicalAnd(is_positive, ones);
+    }
   }
   auto zero_slice = ib->ZerosLike(gathered);
   return {ib->Select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive};
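As a quick sanity check of the shape_func above: with gathered of rank 3 and is_positive of rank 1, the computed reshape target pads trailing 1s so is_positive broadcasts against gathered. A standalone rendition (plain C++, made-up dimensions):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> gathered{8, 4, 5};  // rank 3
  std::vector<int64_t> is_pos{8};          // rank 1
  std::vector<int64_t> res_shape(is_pos);
  // Pad with 1s until the ranks match, exactly as the lambda does.
  for (size_t i = is_pos.size(); i < gathered.size(); ++i) res_shape.push_back(1);
  for (auto d : res_shape) std::cout << d << ' ';  // prints: 8 1 1
  std::cout << '\n';
  return 0;
}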
@@ -117,19 +148,117 @@ NodePtrList TensorScatterPossibleReplacement(const BpropIRBuilder *ib) {
   return {ib->Cast(dx, ib->GetDtype(x)), ib->ZerosLike(indices), ib->Cast(dupdates, ib->GetDtype(updates))};
 }

+ShapeArray RegenerateOutputShapeFunc(const ShapeArray &inputs) {
+  auto x_shape = inputs.at(0);
+  auto indices_shape = inputs.at(1);
+
+  auto axis = inputs.at(2);
+  MS_EXCEPTION_IF_CHECK_FAIL(axis.size() == 1, "axis should be a scalar.");
+  auto axis_value = axis[0];
+  if (axis_value < 0) {
+    axis_value += x_shape.size();
+  }
+
+  std::vector<int64_t> out_shape(x_shape.begin(), x_shape.begin() + axis_value);
+  out_shape.insert(out_shape.end(), indices_shape.begin(), indices_shape.end());
+  out_shape.insert(out_shape.end(), x_shape.begin() + axis_value + 1, x_shape.end());
+  return {out_shape};
+}
+
+ShapeVector RegenerateOutputInferFunc(const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) {
+  if (!invalid_indices.empty()) {
+    return {-1};
+  }
+
+  auto x_rank = inputs.at(0).size();
+  auto indices_rank = inputs.at(1).size();
+  return {static_cast<int64_t>(x_rank + indices_rank - 1)};
+}
+
+ShapeArray Perm1ShapeFunc(const ShapeArray &inputs) {
+  auto dout_rank = SizeToLong(inputs.at(0).size());
+  auto indices_rank = SizeToLong(inputs.at(1).size());
+
+  auto axis = inputs.at(2);
+  MS_EXCEPTION_IF_CHECK_FAIL(axis.size() == 1, "axis should be a scalar.");
+  auto axis_value = axis[0];
+  if (axis_value < 0) {
+    axis_value += dout_rank - indices_rank + 1;
+  }
+
+  std::vector<int64_t> perm(indices_rank);
+  std::iota(perm.begin(), perm.end(), axis_value);
+  int64_t index_end = std::min(dout_rank, axis_value);
+  for (int64_t i = 0; i < index_end; ++i) {
+    perm.push_back(i);
+  }
+  for (int64_t i = axis_value + indices_rank; i < dout_rank; ++i) {
+    perm.push_back(i);
+  }
+
+  return {perm};
+}
+
+ShapeVector PermInferFunc(const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) {
+  if (!invalid_indices.empty()) {
+    return {-1};
+  }
+
+  return {static_cast<int64_t>(inputs.at(0).size())};
+}
+
+ShapeArray Perm2ShapeFunc(const ShapeArray &inputs) {
+  auto x_rank = SizeToLong(inputs.at(0).size());
+  auto axis = inputs.at(1);
+  MS_EXCEPTION_IF_CHECK_FAIL(axis.size() == 1, "axis should be a scalar.");
+  auto axis_value = axis[0];
+
+  if (axis_value < 0) {
+    axis_value += x_rank;
+  }
+
+  std::vector<int64_t> perm(axis_value);
+  std::iota(perm.begin(), perm.end(), 1);
+  perm.push_back(0);
+  for (int64_t i = 1 + axis_value; i < x_rank; ++i) {
+    perm.push_back(i);
+  }
+
+  return {perm};
+}
+
 NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
   auto x = ib->GetInput(kIndex0);
   auto indices = ib->GetInput(kIndex1);
   auto axis = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex4);
   auto orig_indices = indices;
   auto x_shp = ib->GetShape(x);
   auto out_shp = ib->GetShape(dout);
   auto ind_shp = ib->GetShape(indices);
-  auto axis_v = CheckRange(GetIntValue(axis), SizeToLong(x_shp.size()));
   if (out_shp.empty()) {
     dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
   }
+
+  if (IsDynamic(x_shp) || IsDynamic(ind_shp) || IsDynamic(out_shp)) {
+    if (ind_shp.empty()) {
+      indices = ib->Emit("ExpandDims", {indices, ib->Tensor(-1)});
+
+      auto out_shp1 = ib->ShapeCalc({x, indices, axis}, RegenerateOutputShapeFunc, RegenerateOutputInferFunc, {2})[0];
+      dout = ib->Reshape(dout, out_shp1);
+    }
+
+    // Calculate perm.
+    auto perm_1 = ib->ShapeCalc({dout, indices, axis}, Perm1ShapeFunc, PermInferFunc, {2})[0];
+    auto perm_2 = ib->ShapeCalc({x, axis}, Perm2ShapeFunc, PermInferFunc, {1})[0];
+    auto values_transpose = ib->Transpose(dout, perm_1);
+    auto x_grad = ib->Emit("UnsortedSegmentSum",
+                           {values_transpose, indices, ib->Emit("TupleGetItem", {ib->Emit("Shape", {x}), axis})});
+    x_grad = ib->Transpose(x_grad, perm_2);
+    return {x_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
+  }
+
+  auto axis_v = CheckRange(GetIntValue(axis), SizeToLong(x_shp.size()));
   if (ind_shp.empty()) {
     indices = ib->Emit("ExpandDims", {indices, ib->Tensor(-1)});
     ind_shp = ib->GetShape(indices);
@@ -142,7 +271,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
   auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
   auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
   auto params_grad = ib->Transpose(tmp, perm_2);
-  return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
+  return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
 }

 int64_t RecorrectAxis(int64_t axis, size_t rank) {
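To see what the two permutations compute: for dout of rank 4, indices of rank 2, and axis 1, Perm1ShapeFunc moves the indices axes to the front (so UnsortedSegmentSum reduces over them), and Perm2ShapeFunc moves the leading segment axis of the rank-3 result back to position axis. A standalone check (plain C++, sample ranks):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  int64_t dout_rank = 4, indices_rank = 2, axis = 1;
  std::vector<int64_t> perm1(indices_rank);
  std::iota(perm1.begin(), perm1.end(), axis);                                   // [1, 2]
  for (int64_t i = 0; i < std::min(dout_rank, axis); ++i) perm1.push_back(i);    // [1, 2, 0]
  for (int64_t i = axis + indices_rank; i < dout_rank; ++i) perm1.push_back(i);  // [1, 2, 0, 3]

  int64_t x_rank = 3;
  std::vector<int64_t> perm2(axis);
  std::iota(perm2.begin(), perm2.end(), 1);                        // [1]
  perm2.push_back(0);                                              // [1, 0]
  for (int64_t i = 1 + axis; i < x_rank; ++i) perm2.push_back(i);  // [1, 0, 2]

  for (auto p : perm1) std::cout << p << ' ';  // prints: 1 2 0 3
  std::cout << '\n';
  for (auto p : perm2) std::cout << p << ' ';  // prints: 1 0 2
  std::cout << '\n';
  return 0;
}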
@@ -433,8 +562,15 @@ REG_BPROP_BUILDER("StridedSlice").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib)
   auto end = ib->GetInput(kIndex2);
   auto strides = ib->GetInput(kIndex3);
   auto dout = ib->GetInput(kIndex5);
-  auto x_shape = ib->EmitValue(MakeValue(ib->GetShape(x)));
-  auto dx = ib->Emit("StridedSliceGrad", {dout, x_shape, begin, end, strides},
+  auto x_shape_vec = ib->GetShape(x);
+
+  NodePtr x_shape_node;
+  if (IsDynamic(x_shape_vec)) {
+    x_shape_node = ib->Emit("TensorShape", {x});
+  } else {
+    x_shape_node = ib->EmitValue(MakeValue(x_shape_vec));
+  }
+  auto dx = ib->Emit("StridedSliceGrad", {dout, x_shape_node, begin, end, strides},
                      {{"begin_mask", ib->GetAttr("begin_mask")},
                       {"end_mask", ib->GetAttr("end_mask")},
                       {"ellipsis_mask", ib->GetAttr("ellipsis_mask")},
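The guard introduced here is the pattern this PR applies throughout; the Reshape and ExpandDims hunks below repeat it. A statically known shape keeps being embedded as a constant value node, while a dynamic one is fetched at runtime via a TensorShape op. Distilled (sketch reusing the builder calls from this hunk, not compilable on its own):

// Shape fully known at compile time: embed it as a value node.
// Any unknown dim (-1): emit TensorShape to materialize the shape at runtime.
NodePtr shape_node = IsDynamic(shape_vec) ? ib->Emit("TensorShape", {x})
                                          : ib->EmitValue(MakeValue(shape_vec));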
@@ -546,8 +682,14 @@ REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib)
   auto x = ib->GetInput(kIndex0);
   auto shp = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto shapex = ib->GetShape(x);
-  return {ib->Reshape(dout, shapex), ib->ZerosLike(shp)};
+  auto shape_x = ib->GetShape(x);
+  NodePtr dx;
+  if (!IsDynamic(shape_x)) {
+    dx = ib->Reshape(dout, shape_x);
+  } else {
+    dx = ib->Reshape(dout, ib->Emit("TensorShape", {x}));
+  }
+  return {dx, ib->ZerosLike(shp)};
 });

 REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
@@ -667,11 +809,11 @@ REG_BPROP_BUILDER("Concat").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
   for (size_t i = 0; i < input_nums; ++i) {
     x_tuple.push_back(ib->TupleGetItem(x, i));
   }
-  auto concat_offset = ib->ShapeCalc(x_tuple, shape_func, infer_func, {}, input_nums);
+  auto concat_offset = ib->ShapeCalc(x_tuple, shape_func, infer_func, {});
   NodePtrList res;
   for (size_t i = 0; i < input_nums; ++i) {
     auto input = ib->Emit("Shape", {ib->TupleGetItem(x, i)});
-    auto slice_out = ib->Emit(kSliceOpName, {dout, concat_offset[i], input});
+    auto slice_out = ib->Emit(kSliceOpName, {dout, concat_offset.at(i), input});
     res.push_back(slice_out);
   }

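For intuition about concat_offset: each output slice starts where the previous input ended along the concat axis, and the slice size is the input's own shape. A standalone illustration of the offsets (plain C++, made-up shapes; the real offsets come from the expander's shape_func):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int64_t>> shapes{{2, 3}, {4, 3}};  // inputs concatenated on axis 0
  int64_t axis = 0, acc = 0;
  for (const auto &s : shapes) {
    std::vector<int64_t> offset(s.size(), 0);
    offset[axis] = acc;  // where this input's gradient begins inside dout
    acc += s[axis];
    std::cout << '[' << offset[0] << ", " << offset[1] << "] ";
  }
  // prints: [0, 0] [2, 0]
  std::cout << '\n';
  return 0;
}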
@@ -912,8 +1054,14 @@ REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(i
   auto x = ib->GetInput(kIndex0);
   auto axis = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto shapex = ib->GetShape(x);
-  return {ib->Reshape(dout, shapex), ib->ZerosLike(axis)};
+  auto shape_x = ib->GetShape(x);
+  NodePtr dx;
+  if (IsDynamic(shape_x)) {
+    dx = ib->Reshape(dout, ib->Emit("TensorShape", {x}));
+  } else {
+    dx = ib->Reshape(dout, shape_x);
+  }
+  return {dx, ib->ZerosLike(axis)};
 });

 REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
@@ -963,10 +1111,33 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto input_multiples = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto shapex = ib->GetShape(x);
-  auto multiples = GetIntList(input_multiples);
-  auto r_shape = TileShape(multiples, shapex);
-  auto axis = Range(0, static_cast<int64_t>(r_shape.size()), 2);
+
+  auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
+    // {x_shape, multiples}
+    auto r_shape = TileShape(inputs.at(1), inputs.at(0));
+    ShapeVector axis;
+    size_t axis_sz = r_shape.size() / 2;
+    axis.reserve(axis_sz);
+    for (int64_t i = 0; i < static_cast<int64_t>(axis_sz); ++i) {
+      axis.push_back(i * 2);
+    }
+
+    return {r_shape, axis};
+  };
+
+  auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
+    if (!invalid_indices.empty()) {
+      return {-1, -1};
+    }
+    auto x_sz = static_cast<int64_t>(inputs.at(0).size());
+    auto multiples_sz = static_cast<int64_t>(inputs.at(1).size());
+    auto max_sz = x_sz > multiples_sz ? x_sz : multiples_sz;
+    return {2 * max_sz, max_sz};
+  };
+
+  auto calc_res = ib->ShapeCalc({x, input_multiples}, shape_func, infer_func, {1});
+  auto r_shape = calc_res[0];
+  auto axis = calc_res[1];
   auto dout_reshaped = ib->Reshape(dout, r_shape);
   auto dout_dtype = ib->GetDtype(dout_reshaped)->type_id();
   NodePtr dx;
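The shape_func mirrors what the removed TileShape/Range host code did, but it now runs inside a ShapeCalc node, so it also works when x or multiples are dynamic. A standalone illustration under an assumed TileShape convention (interleave each multiple with its dimension, then sum over the even axes to recover x's gradient; plain C++, made-up shapes):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x_shape{2, 3}, multiples{4, 5};
  std::vector<int64_t> r_shape, axis;
  for (size_t i = 0; i < x_shape.size(); ++i) {
    r_shape.push_back(multiples[i]);  // reduced axis (even position)
    r_shape.push_back(x_shape[i]);    // kept axis (odd position)
    axis.push_back(static_cast<int64_t>(2 * i));
  }
  for (auto d : r_shape) std::cout << d << ' ';  // prints: 4 2 5 3
  std::cout << "| ";
  for (auto a : axis) std::cout << a << ' ';     // prints: 0 2
  std::cout << '\n';
  return 0;
}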
@@ -974,15 +1145,16 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
   if (need_reduce.first) {
     if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
       dout_reshaped = ib->Cast(dout_reshaped, kFloat32);
-      dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
+      dx = ib->Emit("ReduceSum", {dout_reshaped, axis}, {{"keep_dims", MakeValue(false)}});
       dx = ib->Cast(dx, dout_dtype);
     } else {
-      dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
+      dx = ib->Emit("ReduceSum", {dout_reshaped, axis}, {{"keep_dims", MakeValue(false)}});
     }
   } else {
     dx = ib->Reshape(dout_reshaped, need_reduce.second);
   }
-  dx = ib->Reshape(dx, shapex);
+  auto shape_x = ib->Emit("Shape", {x});
+  dx = ib->Reshape(dx, shape_x);
   return {dx, ib->ZerosLike(input_multiples)};
 });

@@ -90,8 +90,7 @@ bool ShapeCalcCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
 }

 std::vector<KernelAttr> ShapeCalcCpuKernelMod::GetOpSupport() {
-  static std::vector<KernelAttr> support_list = {
-    {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64)}};
+  static std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
   return support_list;
 }

@@ -72,6 +72,7 @@ int64_t GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abs
   auto value = input_args[kInputIndex2]->BuildValue();
+  if (!IsValueKnown(value)) {
+    num_segments_v = abstract::Shape::kShapeDimAny;
+    return num_segments_v;
+  }
   auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
   auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();