add compute grad interface
parent 3e5d01d3fc
commit cc698d1e08
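
The change threads a per-input "need compute grad" flag through the expander: Node gains a virtual need_compute_grad_out() that defaults to true, FuncNode overrides it with a settable member, and each binary-op bprop body now only materializes a gradient branch when its input actually requires one, passing nullptr otherwise. A minimal standalone sketch of that guard pattern, using simplified stand-in types rather than the real mindspore::expander classes:

// Sketch only: Node and MulBprop are simplified stand-ins, not MindSpore code.
#include <iostream>
#include <memory>

struct Node {
  bool need_compute_grad_out_{true};  // defaults to true, as in the base Node
  virtual bool need_compute_grad_out() const { return need_compute_grad_out_; }
  virtual ~Node() = default;
};
using NodePtr = std::shared_ptr<Node>;

// Mirrors the rewritten "Mul" bprop: each gradient is computed only when the
// corresponding input needs it; a skipped branch stays nullptr.
void MulBprop(const NodePtr &x, const NodePtr &y) {
  NodePtr bc_dx = nullptr;
  NodePtr bc_dy = nullptr;
  if (x->need_compute_grad_out()) bc_dx = std::make_shared<Node>();  // stands for ib->Mul(y, dout)
  if (y->need_compute_grad_out()) bc_dy = std::make_shared<Node>();  // stands for ib->Mul(x, dout)
  std::cout << "dx computed: " << (bc_dx != nullptr) << ", dy computed: " << (bc_dy != nullptr) << "\n";
}

int main() {
  auto x = std::make_shared<Node>();
  auto y = std::make_shared<Node>();
  y->need_compute_grad_out_ = false;  // y does not require grad
  MulBprop(x, y);                     // prints "dx computed: 1, dy computed: 0"
  return 0;
}
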
@@ -37,6 +37,7 @@ NodePtrList BpropBuilder::Run(const NodePtrList &inputs, const mindspore::HashMa
   inputs_ptr_ = &inputs;
   attrs_ptr_ = &attrs;
   instance_name_ = instance_name;
+  input_size_ = inputs.size() - kIndex2;
   return handle.func(this);
 }

@@ -48,7 +48,6 @@ class COMMON_EXPORT BpropBuilder : public Emitter {
   /// \brief Run irbuilder to generate a graph
   NodePtrList Run(const NodePtrList &inputs, const mindspore::HashMap<std::string, ValuePtr> &attrs,
                   const BpropHandle &handle, const std::string &instance_name);

   ValuePtr GetAttr(const std::string &attr) const;
   template <typename S>
   S GetAttr(const std::string &attr) const {

@@ -113,12 +112,14 @@ class COMMON_EXPORT BpropBuilder : public Emitter {
   std::string GetInstanceName() const { return instance_name_; }
   NodePtr TanhGrad(const NodePtr &y, const NodePtr &dy) { return Emit("TanhGrad", {y, dy}); }
   virtual NodePtr OutZeros(const NodePtr &node) { return ZerosLike(node); }
+  size_t input_size() const { return input_size_; }

  protected:
   std::string name_;
   std::string instance_name_;
   const NodePtrList *inputs_ptr_{nullptr};
   const mindspore::HashMap<std::string, ValuePtr> *attrs_ptr_{nullptr};
+  size_t input_size_;
 };

 class IrBuilder : public BpropBuilder {

@@ -259,48 +259,62 @@ std::vector<int64_t> GetIntList(const NodePtr &node) {
   return GetIntList(value);
 }

-NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
-                            const NodePtr &dy, size_t shift) {
+NodePtr StaticBinopGradCommon(BpropBuilder *ib, const NodePtr &dx, const ShapeArray &shape,
+                              const ShapeArray &broadcast_shape, size_t shift, size_t index, bool *is_dynamic_shape) {
+  NodePtr reduce_dx = dx;
+  if (broadcast_shape[kIndex0].empty() || broadcast_shape[kIndex1].empty()) {
+    if (broadcast_shape[index].empty()) {
+      if (shift) {
+        std::vector<int64_t> axis(broadcast_shape[index ^ 1].size());
+        std::iota(axis.begin(), axis.end(), 0LL);
+        reduce_dx = ib->ReduceSum(reduce_dx, axis);
+      } else {
+        reduce_dx = ib->ReduceSum(reduce_dx);
+      }
+    }
+  } else if (!IsDynamic(broadcast_shape[0]) && !IsDynamic(broadcast_shape[1])) {
+    std::vector<std::vector<int64_t>> bc_axis = BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1]);
+    if (!bc_axis[index].empty()) {
+      reduce_dx = ib->ReduceSum(reduce_dx, bc_axis[index], ib->GetRank(reduce_dx) == shape[index].size());
+    }
+    if (ib->GetRank(reduce_dx) != shape[index].size()) {
+      reduce_dx = ib->Reshape(reduce_dx, shape[index]);
+    }
+  } else {
+    *is_dynamic_shape = true;
+  }
+  return reduce_dx;
+}
+
+NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx, const NodePtr &dy,
+                            size_t shift) {
   // Common grad definition for binary operations with shift.
   // The function is usually used in backprop op to reduce additional dimensions
   // created by broadcasting.
   NodePtrList inputs{x, y};
-  ShapeArray shape{ib->GetShape(inputs[0]), ib->GetShape(inputs[1])};
+  ShapeArray shape{ib->GetShape(inputs[kIndex0]), ib->GetShape(inputs[kIndex1])};
   NodePtrList reduce = {dx, dy};
-  if (IsDynamicRank(shape[0]) || IsDynamicRank(shape[1])) {
+  if (IsDynamicRank(shape[kIndex0]) || IsDynamicRank(shape[kIndex1])) {
     return DynBinopGradCommon(ib, x, y, dx, dy, shift);
   }
   if (shape[kIndex0].size() <= shift && shape[kIndex0].size() == shape[kIndex1].size()) {
     return reduce;
   }
-  ShapeVector broadcast_shape[kDim2];
+  ShapeArray broadcast_shape(kDim2);
   for (size_t i = 0; i < kDim2; i++) {
     broadcast_shape[i] = ShapeVector(shape[i].begin(), shape[i].end() - shift);
   }
-
-  if (broadcast_shape[0].empty() || broadcast_shape[1].empty()) {
-    for (size_t i = 0; i < kDim2; i++) {
-      if (broadcast_shape[i].empty()) {
-        if (shift) {
-          std::vector<int64_t> axis(broadcast_shape[i ^ 1].size());
-          std::iota(axis.begin(), axis.end(), 0LL);
-          reduce[i] = ib->ReduceSum(reduce[i], axis);
-        } else {
-          reduce[i] = ib->ReduceSum(reduce[i]);
-        }
-      }
-    }
-  } else if (!IsDynamic(broadcast_shape[0]) && !IsDynamic(broadcast_shape[1])) {
-    std::vector<std::vector<int64_t>> bc_axis = BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1]);
-    for (size_t i = 0; i < kDim2; i++) {
-      if (!bc_axis[i].empty()) {
-        reduce[i] = ib->ReduceSum(reduce[i], bc_axis[i], ib->GetRank(reduce[i]) == shape[i].size());
-      }
-      if (ib->GetRank(reduce[i]) != shape[i].size()) {
-        reduce[i] = ib->Reshape(reduce[i], shape[i]);
-      }
-    }
-  } else {
+  bool is_x_shape_dynamic = false;
+  bool is_y_shape_dynamic = false;
+  if (dx != nullptr) {
+    reduce[kIndex0] =
+      StaticBinopGradCommon(ib, reduce[kIndex0], shape, broadcast_shape, shift, kIndex0, &is_x_shape_dynamic);
+  }
+  if (dy != nullptr) {
+    reduce[kIndex1] =
+      StaticBinopGradCommon(ib, reduce[kIndex1], shape, broadcast_shape, shift, kIndex1, &is_y_shape_dynamic);
+  }
+  if (is_x_shape_dynamic || is_y_shape_dynamic) {
     return DynBinopGradCommon(ib, x, y, dx, dy, shift);
   }
   return reduce;
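
BinopGradCommon reduces each broadcast gradient back to its input's shape by summing over the axes reported by BroadcastGradientArgsInferValue. The following standalone sketch recomputes those axes for static shapes; BroadcastGradientArgs here is a hypothetical helper written for illustration, not the MindSpore implementation:

// Sketch only: derive, for two shapes, the axes along which each input was
// broadcast, i.e. the axes the gradient must be ReduceSum'ed over.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::vector<int64_t>> BroadcastGradientArgs(std::vector<int64_t> x, std::vector<int64_t> y) {
  size_t rank = std::max(x.size(), y.size());
  x.insert(x.begin(), rank - x.size(), 1);  // left-pad the shorter shape with 1s
  y.insert(y.begin(), rank - y.size(), 1);
  std::vector<std::vector<int64_t>> axes(2);
  for (size_t i = 0; i < rank; ++i) {
    // A size-1 dimension meeting a larger one was broadcast, so the incoming
    // gradient carries an extra axis there that has to be summed away.
    if (x[i] == 1 && y[i] != 1) axes[0].push_back(static_cast<int64_t>(i));
    if (y[i] == 1 && x[i] != 1) axes[1].push_back(static_cast<int64_t>(i));
  }
  return axes;
}

int main() {
  auto axes = BroadcastGradientArgs({2, 1, 3}, {4, 3});  // y is padded to {1, 4, 3}
  for (int64_t a : axes[0]) std::cout << "reduce dx over axis " << a << "\n";  // axis 1
  for (int64_t a : axes[1]) std::cout << "reduce dy over axis " << a << "\n";  // axis 0
  return 0;
}
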
@@ -89,20 +89,27 @@ NodePtrList MinimumMaximumGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr
   auto half_ratio = ib->Emit("FillV2", {ib->Shape(dout), ib->Tensor(2, ib->GetDtype(dout))});
   auto half_dout = ib->Div(dout, half_ratio);
   NodePtr equal_mask = ib->Equal(x, y);
-  auto grad_x = ib->Select(equal_mask, half_dout, dout);
-  auto grad_y = ib->Select(equal_mask, half_dout, dout);

   auto zeros = ib->Emit("FillV2", {ib->Shape(dout), ib->Tensor(0, ib->GetDtype(dout))});
   NodePtr is_less = ib->Less(x, y);
   NodePtr is_greater = ib->Greater(x, y);
-  if (is_minimum) {
-    grad_x = ib->Select(is_greater, zeros, grad_x);
-    grad_y = ib->Select(is_less, zeros, grad_y);
-  } else {
-    grad_x = ib->Select(is_less, zeros, grad_x);
-    grad_y = ib->Select(is_greater, zeros, grad_y);
+  NodePtr grad_x = nullptr;
+  NodePtr grad_y = nullptr;
+  if (x->need_compute_grad_out()) {
+    grad_x = ib->Select(equal_mask, half_dout, dout);
+    if (is_minimum) {
+      grad_x = ib->Select(is_greater, zeros, grad_x);
+    } else {
+      grad_x = ib->Select(is_less, zeros, grad_x);
+    }
+  }
+  if (y->need_compute_grad_out()) {
+    grad_y = ib->Select(equal_mask, half_dout, dout);
+    if (is_minimum) {
+      grad_y = ib->Select(is_less, zeros, grad_y);
+    } else {
+      grad_y = ib->Select(is_greater, zeros, grad_y);
+    }
   }

   return BinopGradCommon(ib, x, y, grad_x, grad_y);
 }
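
MinimumMaximumGrad routes dout to whichever input won the comparison and, where x == y, splits it evenly via half_dout. The same rule written out for scalars (a hedged illustration, not builder code):

// Sketch only: the element-wise routing rule of MinimumMaximumGrad on scalars.
#include <iostream>

struct Grads {
  double dx;
  double dy;
};

Grads MinimumGrad(double x, double y, double dout) {
  if (x == y) return {dout / 2, dout / 2};  // the equal_mask -> half_dout case
  if (x < y) return {dout, 0.0};            // x produced the minimum
  return {0.0, dout};                       // y produced the minimum
}

int main() {
  Grads g = MinimumGrad(3.0, 3.0, 1.0);
  std::cout << g.dx << " " << g.dy << "\n";  // 0.5 0.5
  g = MinimumGrad(2.0, 5.0, 1.0);
  std::cout << g.dx << " " << g.dy << "\n";  // 1 0
  return 0;
}
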
@@ -473,7 +480,15 @@ REG_BPROP_BUILDER("Add").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  return BinopGradCommon(ib, x, y, dout, dout);
+  NodePtr dx = nullptr;
+  NodePtr dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    dy = dout;
+  }
+  return BinopGradCommon(ib, x, y, dx, dy);
 });

 REG_BPROP_BUILDER("AddExt").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {

@@ -526,8 +541,14 @@ REG_BPROP_BUILDER("Mul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
     MS_EXCEPTION(TypeError) << "For 'Mul', gradient not support for complex type currently.";
   }
-  auto bc_dx = ib->Mul(y, dout);
-  auto bc_dy = ib->Mul(x, dout);
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Mul(y, dout);
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = ib->Mul(x, dout);
+  }
   return BinopGradCommon(ib, x, y, bc_dx, bc_dy);
 });

@@ -535,7 +556,15 @@ REG_BPROP_BUILDER("Sub").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  return BinopGradCommon(ib, x, y, dout, ib->Emit(kNegOpName, {dout}));
+  NodePtr dx = nullptr;
+  NodePtr dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    dy = ib->Emit(kNegOpName, {dout});
+  }
+  return BinopGradCommon(ib, x, y, dx, dy);
 });

 REG_BPROP_BUILDER("Div").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {

@@ -543,14 +572,20 @@ REG_BPROP_BUILDER("Div").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = ib->Emit(kDivOpName, {dout, y});
-  auto bc_y = -(bc_x * out);
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
   auto x_dtype_id = ib->GetDtypeId(x);
-  if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
-    auto result = BinopGradCommon(ib, x, y, bc_x, bc_y);
-    return {ib->Conj(result[0]), ib->Conj(result[1])};
+  bc_dx = ib->Emit(kDivOpName, {dout, y});
+  if (y->need_compute_grad_out()) {
+    bc_dy = -(bc_dx * out);
   }
-  return BinopGradCommon(ib, x, y, bc_x, bc_y);
+  auto result = BinopGradCommon(ib, x, y, bc_dx, bc_dy);
+  bool is_complex = (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128);
+  if (is_complex) {
+    result[kIndex0] = ib->Conj(result[kIndex0]);
+    result[kIndex1] = y->need_compute_grad_out() ? ib->Conj(result[kIndex1]) : ib->OutZeros(y);
+  }
+  return result;
 });

 REG_BPROP_BUILDER("BitwiseAnd").SetUnusedInputs({i0, i1, i2, i3}).SetBody(ReturnZeros);

@@ -714,8 +749,14 @@ REG_BPROP_BUILDER("Atan2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
   auto tmp = ib->RealDiv(dout, (ib->Add((ib->Emit("Square", {x})), (ib->Emit("Square", {y})))));
-  auto bc_dx = ib->Mul(tmp, y);
-  auto bc_dy = ib->Mul(tmp, (ib->Emit("Neg", {x})));
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Mul(tmp, y);
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = ib->Mul(tmp, (ib->Emit("Neg", {x})));
+  }
   return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

@@ -781,11 +822,17 @@ REG_BPROP_BUILDER("Pow").SetBody(BODYFUNC(ib) {
   if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
     MS_EXCEPTION(TypeError) << "For 'Pow', gradient not support for complex type currently.";
   }
-  auto bc_dx = ib->Mul((ib->Mul(power, (ib->Pow(x, ib->Sub(power, ib->Tensor(1.0, ib->GetDtype(x))))))), dout);
-  x =
-    ib->Select(ib->Less(x, ib->Tensor(0, ib->GetDtype(x))), ib->Fill(1.0, ib->Shape(x), ib->GetDtype(x)->type_id()), x);
-  auto bc_dpower = ib->Mul((ib->Mul(out, (ib->Log(x)))), dout);
-  return {BinopGradCommon(ib, x, power, bc_dx, bc_dpower)};
+  NodePtrList grad_outputs(ib->input_size(), nullptr);
+  if (x->need_compute_grad_out()) {
+    grad_outputs[kIndex0] =
+      ib->Mul((ib->Mul(power, (ib->Pow(x, ib->Sub(power, ib->Tensor(1.0, ib->GetDtype(x))))))), dout);
+  }
+  if (power->need_compute_grad_out()) {
+    x = ib->Select(ib->Less(x, ib->Tensor(0, ib->GetDtype(x))), ib->Fill(1.0, ib->Shape(x), ib->GetDtype(x)->type_id()),
+                   x);
+    grad_outputs[kIndex1] = ib->Mul((ib->Mul(out, (ib->Log(x)))), dout);
+  }
+  return {BinopGradCommon(ib, x, power, grad_outputs[kIndex0], grad_outputs[kIndex1])};
 });

 REG_BPROP_BUILDER("Exp").SetBody(BODYFUNC(ib) {

@@ -1148,9 +1195,12 @@ REG_BPROP_BUILDER("RealDiv").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
   }
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = ib->RealDiv(dout, y);
-  auto bc_y = -(bc_x * out);
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  auto bc_dx = ib->RealDiv(dout, y);
+  NodePtr bc_dy = nullptr;
+  if (y->need_compute_grad_out()) {
+    bc_dy = -(bc_dx * out);
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("DivNoNan").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {

@@ -1158,9 +1208,12 @@ REG_BPROP_BUILDER("DivNoNan").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto out = ib->GetInput(kIndex2);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = ib->DivNoNan(dout, y);
-  auto bc_y = -(bc_x * out);
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  auto bc_dx = ib->DivNoNan(dout, y);
+  NodePtr bc_dy = nullptr;
+  if (y->need_compute_grad_out()) {
+    bc_dy = -(bc_dx * out);
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("Xdivy").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {

@@ -1168,10 +1221,16 @@ REG_BPROP_BUILDER("Xdivy").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
   auto x_dtype = ib->GetDtype(x);
-  auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
-  auto bc_x = (ib->Xdivy(not_zero_x, y)) * dout;
-  auto bc_y = (ib->Xdivy(-x, ib->Emit("Square", {y}))) * dout;
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
+    bc_dx = (ib->Xdivy(not_zero_x, y)) * dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (ib->Xdivy(-x, ib->Emit("Square", {y}))) * dout;
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("FloorDiv").SetUnusedInputs({i0, i1, i2, i3}).SetBody(ReturnZeros);

@@ -1180,11 +1239,16 @@ REG_BPROP_BUILDER("FloorMod").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = dout;
-  auto bc_y = (-dout) * (ib->FloorDiv(x, y));
-  bc_x = ib->Cast(bc_x, ib->GetDtype(x));
-  bc_y = ib->Cast(bc_y, ib->GetDtype(y));
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Cast(dout, ib->GetDtype(x));
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (-dout) * (ib->FloorDiv(x, y));
+    bc_dy = ib->Cast(bc_dy, ib->GetDtype(y));
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("TruncateDiv").SetUnusedInputs({i0, i1, i2, i3}).SetBody(ReturnZeros);

@@ -1193,20 +1257,31 @@ REG_BPROP_BUILDER("TruncateMod").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = dout;
-  auto bc_y = (-dout) * (ib->Emit("TruncateDiv", {x, y}));
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (-dout) * (ib->Emit("TruncateDiv", {x, y}));
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("Mod").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  auto bc_x = dout;
-  auto bc_y = (-dout) * (ib->FloorDiv(x, y));
-  bc_x = ib->Cast(bc_x, ib->GetDtype(x));
-  bc_y = ib->Cast(bc_y, ib->GetDtype(y));
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = ib->Cast(dout, ib->GetDtype(x));
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = (-dout) * (ib->FloorDiv(x, y));
+    bc_dy = ib->Cast(bc_dy, ib->GetDtype(y));
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("Xlogy").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {

@@ -1214,10 +1289,16 @@ REG_BPROP_BUILDER("Xlogy").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
   auto x_dtype = ib->GetDtype(x);
-  auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
-  auto bc_x = ib->Xlogy(not_zero_x, y) * dout;
-  auto bc_y = ib->Xdivy(x, y) * dout;
-  return {BinopGradCommon(ib, x, y, bc_x, bc_y)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    auto not_zero_x = ib->Cast(ib->NotEqual(x, ib->Tensor(0.0, x_dtype)), x_dtype);
+    bc_dx = ib->Xlogy(not_zero_x, y) * dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = ib->Xdivy(x, y) * dout;
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("Sqrt").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {

@@ -1307,11 +1388,17 @@ REG_BPROP_BUILDER("Hypot").SetBody(BODYFUNC(ib) {
   auto x2_f32 = ib->Cast(x2, kFloat32);
   auto out_f32 = ib->Cast(out, kFloat32);
   auto dout_f32 = ib->Cast(dout, kFloat32);
-  auto dx1 = ib->Mul(ib->Div(x1_f32, out_f32), dout_f32);
-  auto dx2 = ib->Mul(ib->Div(x2_f32, out_f32), dout_f32);
+  NodePtr dx1 = nullptr;
+  NodePtr dx2 = nullptr;
+  if (x1->need_compute_grad_out()) {
+    dx1 = ib->Mul(ib->Div(x1_f32, out_f32), dout_f32);
+  }
+  if (x2->need_compute_grad_out()) {
+    dx2 = ib->Mul(ib->Div(x2_f32, out_f32), dout_f32);
+  }
   auto tmp = BinopGradCommon(ib, x1_f32, x2_f32, dx1, dx2);
-  auto result_dx1 = ib->Cast(tmp[0], ib->GetDtype(x1));
-  auto result_dx2 = ib->Cast(tmp[1], ib->GetDtype(x2));
+  auto result_dx1 = x1->need_compute_grad_out() ? ib->Cast(tmp[0], ib->GetDtype(x1)) : ib->OutZeros(x1);
+  auto result_dx2 = x2->need_compute_grad_out() ? ib->Cast(tmp[1], ib->GetDtype(x2)) : ib->OutZeros(x2);
   return {result_dx1, result_dx2};
 });

@@ -1920,7 +2007,15 @@ REG_BPROP_BUILDER("AddV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
   auto x = ib->GetInput(kIndex0);
   auto y = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);
-  return {BinopGradCommon(ib, x, y, dout, dout)};
+  NodePtr bc_dx = nullptr;
+  NodePtr bc_dy = nullptr;
+  if (x->need_compute_grad_out()) {
+    bc_dx = dout;
+  }
+  if (y->need_compute_grad_out()) {
+    bc_dy = dout;
+  }
+  return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
 });

 REG_BPROP_BUILDER("Addcdiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {

@@ -2191,20 +2286,22 @@ REG_BPROP_BUILDER("BatchMatMul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
   auto w = ib->GetInput(kIndex1);
   auto dout = ib->GetInput(kIndex3);

-  NodePtr dx;
-  if (ta) {
-    dx = ib->BatchMatMul(w, dout, tb, true);
-  } else {
-    dx = ib->BatchMatMul(dout, w, false, !tb);
+  NodePtr dx = nullptr;
+  if (x->need_compute_grad_out()) {
+    if (ta) {
+      dx = ib->BatchMatMul(w, dout, tb, true);
+    } else {
+      dx = ib->BatchMatMul(dout, w, false, !tb);
+    }
   }

-  NodePtr dw;
-  if (tb) {
-    dw = ib->BatchMatMul(dout, x, true, ta);
-  } else {
-    dw = ib->BatchMatMul(x, dout, !ta, false);
+  NodePtr dw = nullptr;
+  if (w->need_compute_grad_out()) {
+    if (tb) {
+      dw = ib->BatchMatMul(dout, x, true, ta);
+    } else {
+      dw = ib->BatchMatMul(x, dout, !ta, false);
+    }
   }

   return BinopGradCommon(ib, x, w, dx, dw, 2);
 });

@@ -58,6 +58,7 @@ class COMMON_EXPORT Node : public std::enable_shared_from_this<Node> {
   virtual bool is_used_value() const {
     MS_EXCEPTION(NotImplementedError) << "Base Node not implement is_used_value() method";
   }
+  virtual bool need_compute_grad_out() const { return true; }

  protected:
   // hold the emitter who created this node.

@@ -112,10 +113,13 @@ class COMMON_EXPORT FuncNode : public Node {
   std::string ToString() const override { return value_->ToString(); }
   void set_debug_info(const std::string &debug_info) override {}
   std::string debug_info() const override { return ""; }
+  bool need_compute_grad_out() const override { return need_compute_grad_out_; }
+  void set_need_compute_grad_out(bool need_compute_grad_out) { need_compute_grad_out_ = need_compute_grad_out; }

  private:
   AbstractBasePtr abstract_;
   InputType input_type_;
+  bool need_compute_grad_out_{true};
 };
 using FuncNodePtr = std::shared_ptr<FuncNode>;
 }  // namespace expander

@@ -252,6 +252,7 @@ void FuncBuilder::SetInputs(std::string instance_name, const std::vector<NodePtr
   instance_name_ = std::move(instance_name);
   inputs_ptr_ = inputs;
   attrs_ptr_ = attrs_ptr;
+  input_size_ = inputs->size() - kSizeTwo;
 }

 NodePtrList FuncBuilder::FlattenNode(const NodePtr &input) {

@@ -176,6 +176,27 @@ bool IsValidTensorInput(const ValuePtr &v) {
   MS_EXCEPTION_IF_NULL(v);
   return v->isa<tensor::Tensor>() || v->isa<tensor::MetaSparseTensor>();
 }
+
+bool IsNeedComputeGrad(const ValuePtr &input) {
+  MS_EXCEPTION_IF_NULL(input);
+  if (input->isa<tensor::Tensor>()) {
+    auto input_tensor = input->cast<tensor::TensorPtr>();
+    auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
+    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
+    auto variable = auto_grad_meta_data->variable();
+    if (variable != nullptr && variable->is_need_grad()) {
+      return true;
+    }
+  } else if (input->isa<ValueSequence>()) {
+    auto seq = input->cast<ValueSequencePtr>();
+    if (!seq->value().empty() && !seq->value().front()->isa<tensor::Tensor>()) {
+      return false;
+    }
+    return std::any_of(seq->value().begin(), seq->value().end(),
+                       [](const ValuePtr &val) { return IsNeedComputeGrad(val); });
+  }
+  return false;
+}
 }  // namespace

 TensorPtrList FuncBackwardNode::CallBackward(const TensorPtrList &gradients_in) {
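
IsNeedComputeGrad asks whether any tensor reachable from an input still requires grad, recursing through value sequences. A toy model of that traversal with hypothetical Value types (it drops the real code's first-element type check for brevity):

// Sketch only: recursive "does anything here need grad?" over a value tree.
#include <algorithm>
#include <iostream>
#include <memory>
#include <variant>
#include <vector>

struct Value;
using ValuePtr = std::shared_ptr<Value>;
struct Value {
  // Either a tensor leaf carrying a requires-grad flag, or a nested sequence.
  std::variant<bool, std::vector<ValuePtr>> data;
};

bool IsNeedComputeGrad(const ValuePtr &input) {
  if (std::holds_alternative<bool>(input->data)) {
    return std::get<bool>(input->data);  // tensor leaf: its requires-grad flag
  }
  const auto &seq = std::get<std::vector<ValuePtr>>(input->data);
  return std::any_of(seq.begin(), seq.end(), [](const ValuePtr &v) { return IsNeedComputeGrad(v); });
}

int main() {
  auto t1 = std::make_shared<Value>(Value{false});
  auto t2 = std::make_shared<Value>(Value{true});
  auto seq = std::make_shared<Value>(Value{std::vector<ValuePtr>{t1, t2}});
  std::cout << IsNeedComputeGrad(seq) << "\n";  // 1: at least one element needs grad
  return 0;
}
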
@@ -187,7 +208,12 @@ TensorPtrList FuncBackwardNode::CallBackward(const TensorPtrList &gradients_in)
   const std::vector<NodePtr> cal_grads_node = func()(&ir_builder);
   ValuePtrList cal_grads_values;
   std::transform(cal_grads_node.begin(), cal_grads_node.end(), std::back_inserter(cal_grads_values),
-                 [](const NodePtr &node) { return node->Value(); });
+                 [](const NodePtr &node) -> ValuePtr {
+                   if (node == nullptr) {
+                     return kNone;
+                   }
+                   return node->Value();
+                 });
   auto gradients = PostProcess(cal_grads_values);
   MS_LOG(DEBUG) << "End CallBackward: " << name();
   return gradients;

@@ -197,7 +223,9 @@ NodePtrList FuncBackwardNode::PreProcess(const TensorPtrList &dout, FuncBuilder
   NodePtrList node_inputs;
   node_inputs.reserve(op_inputs_.size() + kSizeFive);
   for (size_t i = 0; i < op_inputs_.size(); ++i) {
-    (void)node_inputs.emplace_back(emitter->NewFuncNode(op_inputs_[i], input_abstract_[i], grad_type_[i]));
+    auto func_node = emitter->NewFuncNode(op_inputs_[i], input_abstract_[i], grad_type_[i]);
+    func_node->set_need_compute_grad_out(IsNeedComputeGrad(op_inputs_[i]));
+    (void)node_inputs.emplace_back(func_node);
   }
   (void)node_inputs.emplace_back(emitter->NewFuncNode(op_output_, out_abstract_, InputType::kOpOutput));
   if (dout.size() == kSizeOne) {

@@ -730,7 +758,8 @@ void FuncGrad::CheckSensShapeAndType(const ValuePtr &sens_gradient) {
   }
 }

-void FuncGrad::PruningGradGraph(const TensorPtrList &weights, const GradAttr &grad_attr, const std::vector<size_t> &grad_position) {
+void FuncGrad::PruningGradGraph(const TensorPtrList &weights, const GradAttr &grad_attr,
+                                const std::vector<size_t> &grad_position) {
   mindspore::HashSet<size_t> grad_pos_list{grad_position.begin(), grad_position.end()};
   // Pruning inputs by position in grad graph
   if (grad_attr.get_by_position) {

@@ -775,7 +804,7 @@ void FuncGrad::PruningGradGraph(const TensorPtrList &weights, const GradAttr &gr
       continue;
     }
     bool is_need_grad = false;
-    for (const auto &edge: variable->func_node()->next_edges()) {
+    for (const auto &edge : variable->func_node()->next_edges()) {
       is_need_grad = is_need_grad || edge.variable->is_need_grad();
     }
     if (!is_need_grad) {