fix fake prim bug and others

luochao 2022-11-30 20:23:19 +08:00
parent 3c800dc56f
commit b44cea3cf7
2 changed files with 80 additions and 65 deletions


@@ -41,17 +41,58 @@ constexpr char kAttrZerosLikeCOO[] = "zero_like_coo_node";
 constexpr char kAttrOnesLikeCSR[] = "ones_like_csr_node";
 constexpr char kAttrOnesLikeCOO[] = "ones_like_coo_node";
 enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
-const std::map<SpecialType, std::shared_ptr<Primitive>> value_type{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
+const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
                                                                    {SpecialType::kOnesLikeType, prim::kPrimOnesLike}};
+const std::vector<PrimitivePtr> kGradBlackList{prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient,
+                                               prim::kPrimUpdateState, prim::kPrimNPUAllocFloatStatus};
 AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type);
-void ClearDeviceAddress(const ValuePtr &value);
+void ClearDeviceAddress(const ValuePtr &value) {
+  std::vector<tensor::TensorPtr> tensors;
+  TensorValueToTensor(value, &tensors);
+  for (auto tensor : tensors) {
+    tensor->set_device_address(nullptr);
+    tensor->set_is_forward_output(false);
+  }
+}
+ValuePtr FilterSensValues(const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(value);
+  if (value->isa<tensor::Tensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
+    return value;
+  } else if (value->isa<ValueSequence>()) {
+    std::vector<ValuePtr> value_list;
+    auto value_seq = value->cast<ValueSequencePtr>();
+    MS_EXCEPTION_IF_NULL(value_seq);
+    for (auto filter_value : value_seq->value()) {
+      if (FilterSensValues(filter_value) != nullptr) {
+        (void)value_list.emplace_back(filter_value);
+      }
+    }
+    return std::make_shared<ValueTuple>(value_list);
+  } else {
+    MS_LOG(DEBUG) << "value type: " << value->ToString();
+    return nullptr;
+  }
+}
+bool IsPrimNeedGrad(const PrimitivePtr &prim) {
+  for (const auto &no_need_grad_prim : kGradBlackList) {
+    if (IsPrimitiveEquals(prim, no_need_grad_prim)) {
+      return false;
+    }
+  }
+  return true;
+}
 AnfNodePtr BuildSpecialLikeCSRTensor(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type) {
   MS_EXCEPTION_IF_NULL(value);
   auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
   MS_EXCEPTION_IF_NULL(csr_tensor);
   auto indptr = csr_tensor->GetIndptr();
   auto cloned_indptr = ShallowCopyTensorValue(indptr);
   ClearDeviceAddress(cloned_indptr);
   auto indptr_node = NewValueNode(cloned_indptr);
   indptr_node->set_abstract(cloned_indptr->ToAbstract()->Broaden());
   auto indices = csr_tensor->GetIndices();
@@ -59,11 +100,13 @@ AnfNodePtr BuildSpecialLikeCSRTensor(const FuncGraphPtr &tape, const ValuePtr &v
   ClearDeviceAddress(cloned_indices);
   auto indices_node = NewValueNode(cloned_indices);
   indices_node->set_abstract(cloned_indices->ToAbstract()->Broaden());
   auto data = csr_tensor->GetValues();
   auto cloned_data = ShallowCopyTensorValue(data);
   ClearDeviceAddress(cloned_data);
   auto value_node = NewValueNode(cloned_data);
   value_node->set_abstract(cloned_data->ToAbstract()->Broaden());
   auto zero_like_value = BuildSpecialLikeValue(tape, cloned_data, type);
   auto shape = csr_tensor->shape();
   auto value_shape = NewValueNode(shape);
@@ -87,11 +130,13 @@ AnfNodePtr BuildSpecialLikeCSRTensor(const FuncGraphPtr &tape, const ValuePtr &v
 AnfNodePtr BuildSpecialLikeCOOTensor(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type) {
   MS_EXCEPTION_IF_NULL(value);
   auto coo_tensor = value->cast<tensor::COOTensorPtr>();
   MS_EXCEPTION_IF_NULL(coo_tensor);
   auto indices = coo_tensor->GetIndices();
   auto cloned_indices = ShallowCopyTensorValue(indices);
   ClearDeviceAddress(cloned_indices);
   auto indices_node = NewValueNode(cloned_indices);
   indices_node->set_abstract(cloned_indices->ToAbstract()->Broaden());
   auto data = coo_tensor->GetValues();
@@ -99,6 +144,7 @@ AnfNodePtr BuildSpecialLikeCOOTensor(const FuncGraphPtr &tape, const ValuePtr &v
   ClearDeviceAddress(cloned_data);
   auto value_node = NewValueNode(cloned_data);
   value_node->set_abstract(cloned_data->ToAbstract()->Broaden());
   auto special_like_value = BuildSpecialLikeValue(tape, cloned_data, type);
   auto shape = coo_tensor->shape();
   auto value_shape = NewValueNode(shape);
@@ -124,7 +170,7 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
   if (value->isa<tensor::Tensor>() || value->isa<Scalar>()) {
     auto value_node = NewValueNode(value);
     value_node->set_abstract(value->ToAbstract()->Broaden());
-    auto primitive = value_type.at(type);
+    auto primitive = kValueType.at(type);
     MS_EXCEPTION_IF_NULL(primitive);
     auto special_like_value = tape->NewCNode({NewValueNode(primitive), value_node});
     special_like_value->set_abstract(value->ToAbstract()->Broaden());
@@ -137,7 +183,7 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
   std::vector<AnfNodePtr> args;
   (void)args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
   auto tuple = value->cast<ValueSequencePtr>();
-  MS_ASSERT(tuple->size() != 0);
+  MS_EXCEPTION_IF_NULL(tuple);
   for (size_t i = 0; i < tuple->size(); ++i) {
     const auto &v = tuple->value()[i];
     AnfNodePtr special_like_value = BuildSpecialLikeValue(tape, v, type);
@@ -197,15 +243,6 @@ FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
   pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
   return after_opt_bg;
 }
-void ClearDeviceAddress(const ValuePtr &value) {
-  std::vector<tensor::TensorPtr> tensors;
-  TensorValueToTensor(value, &tensors);
-  for (auto tensor : tensors) {
-    tensor->set_device_address(nullptr);
-    tensor->set_is_forward_output(false);
-  }
-}
 }  // namespace
 AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
@@ -286,8 +323,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
   if (prim == nullptr) {
     MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString();
   }
-  if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) ||
-      IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
+  if (!IsPrimNeedGrad(prim)) {
     MS_LOG(DEBUG) << "Prim " << prim->name() << " does not need op grad";
     return true;
   }
@@ -310,7 +346,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
   } else {
     mindspore::BuildBprop(input_node, &outputs, &users_);
     if (outputs.empty()) {
-      MS_LOG(DEBUG) << "The bprop output should not be empty" << grad_param->cnode->DebugString();
+      MS_LOG(DEBUG) << "Expander has no bprop for this prim: " << grad_param->cnode->DebugString();
      BuildCustomBpropCNode(input_node, &outputs);
     }
   }
@@ -321,10 +357,13 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
     BuildCustomBpropCNode(input_node, &outputs);
   }
 #endif
-  if (outputs.empty()) {
-    MS_LOG(EXCEPTION) << "The bprop output should not be empty" << grad_param->cnode->DebugString();
+  if (!outputs.empty()) {
+    UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
+  } else {
+    MS_LOG(DEBUG) << "This op has no custom bprop: " << grad_param->cnode->DebugString();
+    variable_adjoint->set_is_fake_bprop(true);
+    variable_adjoint->set_fake_prim_name(prim->name());
   }
-  UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
   anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
   // Record last node for backpropagation
   last_node_ = grad_param->cnode;
@@ -416,7 +455,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
   MS_EXCEPTION_IF_NULL(sens_out);
   MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
   last_node_ = output_node;
-  sens_value_ = sens_out;
+  sens_value_ = FilterSensValues(sens_out);
 }
 FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
@@ -522,7 +561,6 @@ bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_
 void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode,
                                        const std::vector<CNodePtr> &dins, const ValuePtrList &op_args) {
   MS_EXCEPTION_IF_NULL(cnode);
-  MS_ASSERT(dins.size() == op_args.size());
   if (dins.size() != op_args.size()) {
     MS_LOG(EXCEPTION) << "The size of dins is not the same as op_args";
   }
@@ -693,30 +731,6 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
   return;
 }
-void AutoGradCellImpl::BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
-  auto prim = GetCNodePrimitive(cnode);
-  if (prim == nullptr) {
-    MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
-  }
-  auto prim_py = prim->cast<PrimitivePyPtr>();
-  MS_EXCEPTION_IF_NULL(prim_py);
-  auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
-  (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
-  std::vector<AnfNodePtr> inputs{NewValueNode(fake_bprop)};
-  for (size_t i = 1; i < cnode->size(); ++i) {
-    (void)inputs.emplace_back(cnode->input(i));
-  }
-  int index = cnode->size() - 1;
-  auto dout = cnode->input(index);
-  auto output = tape_->NewCNode(inputs);
-  output->set_abstract(dout->abstract()->Broaden());
-  (void)outputs->emplace_back(output);
-  AddUser(dout, output, index);
-  return;
-}
 void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
   auto prim = GetCNodePrimitive(cnode);
   if (prim == nullptr) {
@@ -737,7 +751,6 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
   }
   if (!fn || py::isinstance<py::none>(fn)) {
     MS_LOG(INFO) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
-    BuildFakeBpropCNode(cnode, outputs);
     return;
   }
   prim_py->AddBackwardHookFn(0, fn);
@@ -793,7 +806,7 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
     }
     if (iter->first->cast<CNodePtr>() == last_node_) {
       auto &variable = anfnode_to_variable_adjoint_[last_node_];
-      variable->set_is_need_grad(true);
+      variable->set_is_need_propagate(true);
       return iter;
     }
   }
@@ -803,11 +816,14 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
 void AutoGradCellImpl::BackPropagate() {
   const auto &last_node_reverse_iter = GetLastNodeReverseIter();
   for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
-    MS_LOG(DEBUG) << "BackPropagate cnode " << iter->first->DebugString();
+    MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
     auto variable = iter->second;
-    if (!variable->is_need_grad()) {
+    if (!variable->is_need_propagate()) {
       continue;
     }
+    if (variable->is_need_propagate() && variable->is_fake_bprop()) {
+      MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has no corresponding bprop!";
+    }
     auto fn = variable->fn();
     // Replace fake dout with real dout
     Replace(fn->fake_dout(), fn->RealDout());
@@ -823,7 +839,7 @@ void AutoGradCellImpl::BackPropagate() {
       }
       auto last_variable = anfnode_to_variable_adjoint_[node];
      last_variable->fn()->UpdateAccumulativeDout(din);
-      last_variable->set_is_need_grad(true);
+      last_variable->set_is_need_propagate(true);
     }
   }
 }
@@ -987,7 +1003,6 @@ void AutoGradCellImpl::AddUser(const AnfNodePtr &node, const CNodePtr &user, siz
 void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
   if (users_.find(old_node) == users_.end()) {
-    MS_LOG(DEBUG) << "Can not find old node: " << old_node->DebugString();
     return;
   }
   auto &old_node_users = users_[old_node];
@@ -1021,15 +1036,6 @@ void AutoGradCellImpl::ElimateTupleGetItem() {
   }
 }
-void AutoGradCellImpl::ClearDeviceAddress(const ValuePtr &out) {
-  std::vector<tensor::TensorPtr> tensors;
-  TensorValueToTensor(out, &tensors);
-  for (auto tensor : tensors) {
-    tensor->set_device_address(nullptr);
-    tensor->set_is_forward_output(false);
-  }
-}
 void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
   const auto &parameters = tape_->parameters();
   auto cell_inputs_size = cell_inputs_.size();

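The behavioral core of this commit is visible in the `KPynativeOp` and `BackPropagate` hunks above: a primitive with no usable bprop no longer gets an eager `fake_bprop` CNode (the removed `BuildFakeBpropCNode`); instead its variable is flagged via `set_is_fake_bprop`/`set_fake_prim_name`, and `BackPropagate` raises only if gradient flow actually reaches that node. A minimal standalone sketch of this deferred-error pattern, in plain C++ with hypothetical types rather than the MindSpore API:

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for VariableNode: instead of failing at graph-build
// time, it records whether the bprop is fake and which primitive lacked one.
struct Variable {
  std::string prim_name;
  bool is_fake_bprop = false;      // no real bprop was found for this prim
  bool is_need_propagate = false;  // set once a gradient actually reaches it
};

// Mirrors the new BackPropagate logic: unreachable nodes are skipped, and the
// missing-bprop error fires only for nodes the gradient actually flows through.
void BackPropagate(std::vector<Variable> &variables) {
  for (auto &v : variables) {
    if (!v.is_need_propagate) {
      continue;
    }
    if (v.is_fake_bprop) {
      throw std::runtime_error(v.prim_name + " op has no corresponding bprop!");
    }
    // ... accumulate dout and mark this node's inputs as need-propagate ...
  }
}

int main() {
  std::vector<Variable> vars{{"ReLU", false, true},
                             {"MyCustomOp", true, false}};  // fake, unreached
  BackPropagate(vars);  // succeeds: the fake node is never propagated to
  std::cout << "ok\n";
}

Deferring the failure means an op whose gradient is never actually requested no longer aborts the whole backward pass, which appears to be the "fake prim" bug named in the commit title.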

@@ -21,6 +21,7 @@
 #include <utility>
 #include <map>
 #include <vector>
+#include <string>
 #include "ir/anf.h"
 #include "ir/func_graph.h"
@@ -91,8 +92,12 @@ class VariableNode {
   ValuePtr out_value() const { return out_value_; }
   FunctionNodePtr fn() const { return fn_; }
-  bool is_need_grad() const { return is_need_grad_; }
-  void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }
+  const string &fake_prim_name() const { return fake_prim_name_; }
+  void set_fake_prim_name(const string &fake_prim_name) { fake_prim_name_ = fake_prim_name; }
+  bool is_fake_bprop() const { return is_fake_bprop_; }
+  void set_is_fake_bprop(bool is_fake_bprop) { is_fake_bprop_ = is_fake_bprop; }
+  bool is_need_propagate() const { return is_need_propagate_; }
+  void set_is_need_propagate(bool is_need_propagate) { is_need_propagate_ = is_need_propagate; }
   AnfNodePtr k_node() const { return k_node_; }
   void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
@@ -100,8 +105,13 @@ class VariableNode {
   // Abstract bprop function
   FunctionNodePtr fn_;
   ValuePtr out_value_;
-  bool is_need_grad_{false};
-  // k mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
+  // If the node has no bprop, record its prim name
+  std::string fake_prim_name_;
+  // Whether this node's bprop is a fake one
+  bool is_fake_bprop_{false};
+  // Whether gradient propagation needs to reach this node
+  bool is_need_propagate_{false};
+  // K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
   AnfNodePtr k_node_{nullptr};
 };
 using VariableNodePtr = std::shared_ptr<VariableNode>;
@@ -179,7 +189,6 @@ class AutoGradCellImpl {
   void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
   void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
   void ElimateTupleGetItem();
-  void ClearDeviceAddress(const ValuePtr &out);
   // Fbprop
   void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
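For reference, the new `FilterSensValues` in the first file keeps tensor, COO, and CSR leaves of a possibly nested sens output and drops everything else before it is stored in `sens_value_`. A rough standalone analogue of that recursion, using a hypothetical `Value` tree instead of MindSpore's `ValuePtr` hierarchy:

#include <cassert>
#include <memory>
#include <vector>

// Hypothetical value tree standing in for the ValuePtr hierarchy:
// a tensor-like leaf, or a sequence of children (mirrors ValueSequence).
struct Value {
  bool is_tensor_like = false;
  std::vector<std::shared_ptr<Value>> children;  // non-empty => sequence
};

// Returns the value if it should be kept as part of sens, nullptr otherwise.
// Note: like FilterSensValues above, a kept sequence element is inserted
// unfiltered; only each sequence's own level is pruned.
std::shared_ptr<Value> FilterSens(const std::shared_ptr<Value> &v) {
  if (v->is_tensor_like) {
    return v;  // tensor/COO/CSR leaves pass through
  }
  if (!v->children.empty()) {
    auto filtered = std::make_shared<Value>();
    for (const auto &child : v->children) {
      if (FilterSens(child) != nullptr) {
        filtered->children.push_back(child);
      }
    }
    return filtered;
  }
  return nullptr;  // scalar/None-like values are dropped from sens
}

int main() {
  auto tensor = std::make_shared<Value>();
  tensor->is_tensor_like = true;
  auto scalar = std::make_shared<Value>();  // not tensor-like, no children
  auto seq = std::make_shared<Value>();
  seq->children = {tensor, scalar};
  auto kept = FilterSens(seq);
  assert(kept != nullptr && kept->children.size() == 1);  // scalar dropped
}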