fix fake prim bug and others

parent 3c800dc56f
commit b44cea3cf7
@@ -41,17 +41,58 @@ constexpr char kAttrZerosLikeCOO[] = "zero_like_coo_node";
 constexpr char kAttrOnesLikeCSR[] = "ones_like_csr_node";
 constexpr char kAttrOnesLikeCOO[] = "ones_like_coo_node";
 enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
-const std::map<SpecialType, std::shared_ptr<Primitive>> value_type{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
+const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
                                                                    {SpecialType::kOnesLikeType, prim::kPrimOnesLike}};
+const std::vector<PrimitivePtr> kGradBlackList{prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient,
+                                               prim::kPrimUpdateState, prim::kPrimNPUAllocFloatStatus};
 AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type);
-void ClearDeviceAddress(const ValuePtr &value);
+void ClearDeviceAddress(const ValuePtr &value) {
+  std::vector<tensor::TensorPtr> tensors;
+  TensorValueToTensor(value, &tensors);
+  for (auto tensor : tensors) {
+    tensor->set_device_address(nullptr);
+    tensor->set_is_forward_output(false);
+  }
+}
+
+ValuePtr FilterSensValues(const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(value);
+  if (value->isa<tensor::Tensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
+    return value;
+  } else if (value->isa<ValueSequence>()) {
+    std::vector<ValuePtr> value_list;
+    auto value_seq = value->cast<ValueSequencePtr>();
+    MS_EXCEPTION_IF_NULL(value_seq);
+    for (auto filter_value : value_seq->value()) {
+      if (FilterSensValues(filter_value) != nullptr) {
+        (void)value_list.emplace_back(filter_value);
+      }
+    }
+    return std::make_shared<ValueTuple>(value_list);
+  } else {
+    MS_LOG(DEBUG) << "value type: " << value->ToString();
+    return nullptr;
+  }
+}
+
+bool IsPrimNeedGrad(const PrimitivePtr &prim) {
+  for (const auto &no_need_grad_prim : kGradBlackList) {
+    if (IsPrimitiveEquals(prim, no_need_grad_prim)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 AnfNodePtr BuildSpecialLikeCSRTensor(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type) {
   MS_EXCEPTION_IF_NULL(value);
 
   auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
   MS_EXCEPTION_IF_NULL(csr_tensor);
   auto indptr = csr_tensor->GetIndptr();
   auto cloned_indptr = ShallowCopyTensorValue(indptr);
+  ClearDeviceAddress(cloned_indptr);
+
   auto indptr_node = NewValueNode(cloned_indptr);
   indptr_node->set_abstract(cloned_indptr->ToAbstract()->Broaden());
   auto indices = csr_tensor->GetIndices();
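As context for the new FilterSensValues helper above: it walks a possibly nested sens value and keeps only tensor-like entries (Tensor, COOTensor, CSRTensor), rebuilding a tuple from the children that survive; note it re-emits the original child, not the recursively filtered copy. Below is a minimal standalone sketch of the same recursive pattern, using std::variant stand-ins rather than MindSpore's ValuePtr hierarchy; Tensor, Other, Seq, and Filter are illustrative names, not the real API.

#include <iostream>
#include <memory>
#include <optional>
#include <variant>
#include <vector>

// Illustrative stand-ins for MindSpore's ValuePtr hierarchy.
struct Tensor { int id; };
struct Other {};  // e.g. a value that carries no gradient
struct Seq;
using Value = std::variant<Tensor, Other, std::shared_ptr<Seq>>;
struct Seq { std::vector<Value> items; };

// Same shape as FilterSensValues: tensors pass through, sequences are
// rebuilt from the children that survive, anything else is dropped.
std::optional<Value> Filter(const Value &v) {
  if (std::holds_alternative<Tensor>(v)) {
    return v;
  }
  if (auto seq = std::get_if<std::shared_ptr<Seq>>(&v)) {
    auto out = std::make_shared<Seq>();
    for (const auto &item : (*seq)->items) {
      if (Filter(item).has_value()) {
        out->items.push_back(item);  // keeps the original child, like the source
      }
    }
    return Value{out};
  }
  return std::nullopt;  // mirrors the "return nullptr" branch
}

int main() {
  auto seq = std::make_shared<Seq>();
  seq->items = {Tensor{0}, Other{}, Tensor{1}};
  auto filtered = Filter(Value{seq});
  const auto &kept = std::get<std::shared_ptr<Seq>>(*filtered)->items;
  std::cout << "kept " << kept.size() << " of 3 values\n";
  return 0;
}

Built with -std=c++17, this prints "kept 2 of 3 values", mirroring how a non-tensor sens entry is silently dropped rather than treated as an error.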
@@ -59,11 +100,13 @@ AnfNodePtr BuildSpecialLikeCSRTensor(const FuncGraphPtr &tape, const ValuePtr &v
+  ClearDeviceAddress(cloned_indices);
   auto indices_node = NewValueNode(cloned_indices);
   indices_node->set_abstract(cloned_indices->ToAbstract()->Broaden());
 
   auto data = csr_tensor->GetValues();
   auto cloned_data = ShallowCopyTensorValue(data);
+  ClearDeviceAddress(cloned_data);
   auto value_node = NewValueNode(cloned_data);
   value_node->set_abstract(cloned_data->ToAbstract()->Broaden());
 
   auto zero_like_value = BuildSpecialLikeValue(tape, cloned_data, type);
   auto shape = csr_tensor->shape();
   auto value_shape = NewValueNode(shape);
@@ -87,11 +130,13 @@ AnfNodePtr BuildSpecialLikeCSRTensor(const FuncGraphPtr &tape, const ValuePtr &v
 
 AnfNodePtr BuildSpecialLikeCOOTensor(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type) {
   MS_EXCEPTION_IF_NULL(value);
 
   auto coo_tensor = value->cast<tensor::COOTensorPtr>();
   MS_EXCEPTION_IF_NULL(coo_tensor);
   auto indices = coo_tensor->GetIndices();
   auto cloned_indices = ShallowCopyTensorValue(indices);
+  ClearDeviceAddress(cloned_indices);
+
   auto indices_node = NewValueNode(cloned_indices);
   indices_node->set_abstract(cloned_indices->ToAbstract()->Broaden());
   auto data = coo_tensor->GetValues();
@@ -99,6 +144,7 @@ AnfNodePtr BuildSpecialLikeCOOTensor(const FuncGraphPtr &tape, const ValuePtr &v
+  ClearDeviceAddress(cloned_data);
   auto value_node = NewValueNode(cloned_data);
   value_node->set_abstract(cloned_data->ToAbstract()->Broaden());
 
   auto special_like_value = BuildSpecialLikeValue(tape, cloned_data, type);
   auto shape = coo_tensor->shape();
   auto value_shape = NewValueNode(shape);
@@ -124,7 +170,7 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
   if (value->isa<tensor::Tensor>() || value->isa<Scalar>()) {
     auto vlaue_node = NewValueNode(value);
     vlaue_node->set_abstract(value->ToAbstract()->Broaden());
-    auto primitive = value_type.at(type);
+    auto primitive = kValueType.at(type);
     MS_EXCEPTION_IF_NULL(primitive);
     auto special_like_value = tape->NewCNode({NewValueNode(primitive), vlaue_node});
     special_like_value->set_abstract(value->ToAbstract()->Broaden());
@@ -137,7 +183,7 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
     std::vector<AnfNodePtr> args;
     (void)args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
     auto tuple = value->cast<ValueSequencePtr>();
-    MS_ASSERT(tuple->size() != 0);
+    MS_EXCEPTION_IF_NULL(tuple);
     for (size_t i = 0; i < tuple->size(); ++i) {
       const auto &v = tuple->value()[i];
       AnfNodePtr special_like_value = BuildSpecialLikeValue(tape, v, type);
@@ -197,15 +243,6 @@ FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
   pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
   return after_opt_bg;
 }
 
-void ClearDeviceAddress(const ValuePtr &value) {
-  std::vector<tensor::TensorPtr> tensors;
-  TensorValueToTensor(value, &tensors);
-  for (auto tensor : tensors) {
-    tensor->set_device_address(nullptr);
-    tensor->set_is_forward_output(false);
-  }
-}
 }  // namespace
 
 AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
@@ -286,8 +323,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
   if (prim == nullptr) {
     MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString();
   }
-  if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) ||
-      IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
+  if (!IsPrimNeedGrad(prim)) {
     MS_LOG(DEBUG) << "Prim " << prim->name() << " not need do op grad";
     return true;
   }
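The kGradBlackList table introduced earlier replaces the hard-coded IsPrimitiveEquals chain that this hunk deletes, so skipping another primitive (kPrimNPUAllocFloatStatus here) becomes a one-line table edit. The loop in IsPrimNeedGrad could equally be written with std::none_of; a standalone sketch with std::string standing in for PrimitivePtr:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// String stand-ins for PrimitivePtr / IsPrimitiveEquals.
const std::vector<std::string> kGradBlackList{"MakeTuple", "TupleGetItem", "StopGradient",
                                              "UpdateState", "NPUAllocFloatStatus"};

bool IsPrimNeedGrad(const std::string &prim) {
  return std::none_of(kGradBlackList.begin(), kGradBlackList.end(),
                      [&prim](const std::string &p) { return p == prim; });
}

int main() {
  std::cout << std::boolalpha << IsPrimNeedGrad("MatMul") << "\n";     // true
  std::cout << std::boolalpha << IsPrimNeedGrad("MakeTuple") << "\n";  // false
  return 0;
}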
@@ -310,7 +346,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
   } else {
     mindspore::BuildBprop(input_node, &outputs, &users_);
     if (outputs.empty()) {
-      MS_LOG(DEBUG) << "The bprop output should not be empty" << grad_param->cnode->DebugString();
+      MS_LOG(DEBUG) << "expander has no bprop of this prim: " << grad_param->cnode->DebugString();
       BuildCustomBpropCNode(input_node, &outputs);
     }
   }
@@ -321,10 +357,13 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
     BuildCustomBpropCNode(input_node, &outputs);
   }
 #endif
-  if (outputs.empty()) {
-    MS_LOG(EXCEPTION) << "The bprop output should not be empty" << grad_param->cnode->DebugString();
+  if (!outputs.empty()) {
+    UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
+  } else {
+    MS_LOG(DEBUG) << "this op has not custom bprop: " << grad_param->cnode->DebugString();
+    variable_adjoint->set_is_fake_bprop(true);
+    variable_adjoint->set_fake_prim_name(prim->name());
   }
-  UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
   anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
   // record last_node for brackpropagate
   last_node_ = grad_param->cnode;
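This hunk is the fake prim fix named in the commit title: an op whose bprop cannot be found is no longer an immediate hard error. The adjoint is tagged via set_is_fake_bprop/set_fake_prim_name, and the exception fires later, in BackPropagate, only if a gradient actually has to flow through that node. A minimal sketch of the deferral, with an invented Adjoint struct standing in for VariableNode (all names illustrative):

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for VariableNode: only the fields this fix touches.
struct Adjoint {
  std::string prim_name;
  bool is_fake_bprop{false};      // tagged when no bprop was found
  bool is_need_propagate{false};  // set when a gradient reaches this node
};

// Stand-in for BackPropagate: the error is raised only for nodes that are
// both fake and actually on the gradient path.
void BackPropagate(const std::vector<Adjoint> &adjoints) {
  for (const auto &v : adjoints) {
    if (!v.is_need_propagate) {
      continue;  // unused node: a missing bprop is harmless
    }
    if (v.is_fake_bprop) {
      throw std::runtime_error(v.prim_name + " op has not corresponding bprop!");
    }
    // ... accumulate and propagate dout here ...
  }
}

int main() {
  // Fake bprop off the gradient path: fine.
  BackPropagate({{"Print", true, false}});
  std::cout << "unused fake bprop ignored\n";
  // Fake bprop on the gradient path: the deferred error fires.
  try {
    BackPropagate({{"Print", true, true}});
  } catch (const std::exception &e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}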
@@ -416,7 +455,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
   MS_EXCEPTION_IF_NULL(sens_out);
   MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
   last_node_ = output_node;
-  sens_value_ = sens_out;
+  sens_value_ = FilterSensValues(sens_out);
 }
 
 FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
@@ -522,7 +561,6 @@ bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_
 void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode,
                                        const std::vector<CNodePtr> &dins, const ValuePtrList &op_args) {
   MS_EXCEPTION_IF_NULL(cnode);
-  MS_ASSERT(dins.size() == op_args.size());
   if (dins.size() != op_args.size()) {
     MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
   }
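Dropping MS_ASSERT in favor of the explicit size check matters because assert-style macros are normally compiled out of release builds (standard assert becomes a no-op once NDEBUG is defined), so the runtime check is the only one that survives in production. An illustration with the standard macro; the function and argument values are invented for the demo:

// Compile with and without -DNDEBUG: the assert vanishes in release,
// the explicit check still fires.
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <vector>

void UpdateNextEdges(const std::vector<int> &dins, const std::vector<int> &op_args) {
  assert(dins.size() == op_args.size());  // no-op when NDEBUG is defined
  if (dins.size() != op_args.size()) {    // always enforced
    throw std::invalid_argument("The size of dins is not same as op_args");
  }
}

int main() {
  try {
    UpdateNextEdges({1, 2}, {1});
  } catch (const std::exception &e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}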
@@ -693,30 +731,6 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
   return;
 }
 
-void AutoGradCellImpl::BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
-  auto prim = GetCNodePrimitive(cnode);
-  if (prim == nullptr) {
-    MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
-  }
-
-  auto prim_py = prim->cast<PrimitivePyPtr>();
-  MS_EXCEPTION_IF_NULL(prim_py);
-
-  auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
-  (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
-  std::vector<AnfNodePtr> inputs{NewValueNode(fake_bprop)};
-  for (size_t i = 1; i < cnode->size(); ++i) {
-    (void)inputs.emplace_back(cnode->input(i));
-  }
-  int index = cnode->size() - 1;
-  auto dout = cnode->input(index);
-  auto output = tape_->NewCNode(inputs);
-  output->set_abstract(dout->abstract()->Broaden());
-  (void)outputs->emplace_back(output);
-  AddUser(dout, output, index);
-  return;
-}
-
 void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
   auto prim = GetCNodePrimitive(cnode);
   if (prim == nullptr) {
@@ -737,7 +751,6 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
   }
   if (!fn || py::isinstance<py::none>(fn)) {
     MS_LOG(INFO) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
-    BuildFakeBpropCNode(cnode, outputs);
     return;
   }
   prim_py->AddBackwardHookFn(0, fn);
@@ -793,7 +806,7 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
     }
     if (iter->first->cast<CNodePtr>() == last_node_) {
       auto &variable = anfnode_to_variable_adjoint_[last_node_];
-      variable->set_is_need_grad(true);
+      variable->set_is_need_propagate(true);
       return iter;
     }
   }
@@ -803,11 +816,14 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
 void AutoGradCellImpl::BackPropagate() {
   const auto &last_node_reverse_iter = GetLastNodeReverseIter();
   for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
-    MS_LOG(DEBUG) << "BackPropagate cnode " << iter->first->DebugString();
+    MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
     auto variable = iter->second;
-    if (!variable->is_need_grad()) {
+    if (!variable->is_need_propagate()) {
       continue;
     }
+    if (variable->is_need_propagate() && variable->is_fake_bprop()) {
+      MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has not corresponding bprop!";
+    }
     auto fn = variable->fn();
     // replace real dout to fake dout
     Replace(fn->fake_dout(), fn->RealDout());
@@ -823,7 +839,7 @@ void AutoGradCellImpl::BackPropagate() {
       }
       auto last_variable = anfnode_to_variable_adjoint_[node];
       last_variable->fn()->UpdateAccumulativeDout(din);
-      last_variable->set_is_need_grad(true);
+      last_variable->set_is_need_propagate(true);
     }
   }
 }
@@ -987,7 +1003,6 @@ void AutoGradCellImpl::AddUser(const AnfNodePtr &node, const CNodePtr &user, siz
 
 void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
   if (users_.find(old_node) == users_.end()) {
    MS_LOG(DEBUG) << "Can not find old node: " << old_node->DebugString();
     return;
   }
   auto &old_node_users = users_[old_node];
@@ -1021,15 +1036,6 @@ void AutoGradCellImpl::ElimateTupleGetItem() {
   }
 }
 
-void AutoGradCellImpl::ClearDeviceAddress(const ValuePtr &out) {
-  std::vector<tensor::TensorPtr> tensors;
-  TensorValueToTensor(out, &tensors);
-  for (auto tensor : tensors) {
-    tensor->set_device_address(nullptr);
-    tensor->set_is_forward_output(false);
-  }
-}
-
 void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
   const auto &parameters = tape_->parameters();
   auto cell_inputs_size = cell_inputs_.size();
@@ -21,6 +21,7 @@
 #include <utility>
 #include <map>
 #include <vector>
+#include <string>
 #include "ir/anf.h"
 #include "ir/func_graph.h"
 
@@ -91,8 +92,12 @@ class VariableNode {
 
   ValuePtr out_value() const { return out_value_; }
   FunctionNodePtr fn() const { return fn_; }
   bool is_need_grad() const { return is_need_grad_; }
   void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }
+  const string &fake_prim_name() const { return fake_prim_name_; }
+  void set_fake_prim_name(const string &fake_prim_name) { fake_prim_name_ = fake_prim_name; }
+  bool is_fake_bprop() const { return is_fake_bprop_; }
+  void set_is_fake_bprop(bool is_fake_bprop) { is_fake_bprop_ = is_fake_bprop; }
   bool is_need_propagate() const { return is_need_propagate_; }
   void set_is_need_propagate(bool is_need_grad) { is_need_propagate_ = is_need_grad; }
   AnfNodePtr k_node() const { return k_node_; }
   void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
 
@@ -100,8 +105,13 @@ class VariableNode {
   // Abstract bprop function
   FunctionNodePtr fn_;
   ValuePtr out_value_;
   bool is_need_grad_{false};
-  // k mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
+  // If node has not bprop, we record its prim name
+  std::string fake_prim_name_;
+  // Record this node is a fake bprop
+  bool is_fake_bprop_{false};
   // Flag to judge need to propagrate
   bool is_need_propagate_{false};
+  // K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
   AnfNodePtr k_node_{nullptr};
 };
 using VariableNodePtr = std::shared_ptr<VariableNode>;
@@ -179,7 +189,6 @@ class AutoGradCellImpl {
   void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
   void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
   void ElimateTupleGetItem();
-  void ClearDeviceAddress(const ValuePtr &out);
 
   // Fbprop
   void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);