!15725 [lite]fix train bug

From: @xu_anyue
Reviewed-by: @HilbertDavid,@jpc_chenjianping
Signed-off-by: @jpc_chenjianping
This commit is contained in:
mindspore-ci-bot 2021-04-27 18:45:45 +08:00 committed by Gitee
commit 7835f73fea
2 changed files with 126 additions and 60 deletions

View File

@ -110,6 +110,40 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
return RET_OK; return RET_OK;
} }
int FetchFromDefaultParam(const ParameterPtr &param_node, DataInfo *data_info) {
MS_ASSERT(param_node != nullptr && data_info != nullptr);
ShapeVector shape_vector;
TypeId data_type;
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
if (status != RET_OK) {
MS_LOG(ERROR) << "get data type and shape from param node failed.";
return RET_ERROR;
}
data_info->data_type_ = data_type;
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
size_t offset = 0;
if (!shape_vector.empty() && data_type == kObjectTypeString) {
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
if (status != RET_OK) {
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
return RET_ERROR;
}
}
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
}
}
return RET_OK;
}
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type, int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
bool train_flag, DataInfo *data_info) { bool train_flag, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr); MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
@ -230,10 +264,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
DataInfo *data_info) { DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr); MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto param_node = cnode->input(index)->cast<ParameterPtr>(); auto param_node = cnode->input(index)->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "input node is not parameter node.";
return RET_ERROR;
}
data_info->format_ = GetFormatByFmk(fmk_type); data_info->format_ = GetFormatByFmk(fmk_type);
if (data_info->format_ < 0) { if (data_info->format_ < 0) {
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type; MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
return lite::RET_ERROR; return RET_ERROR;
} }
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) { if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_; MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
@ -245,38 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) { if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)); data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
} }
ShapeVector shape_vector; if (FetchFromDefaultParam(param_node, data_info) != RET_OK) {
TypeId data_type; MS_LOG(ERROR) << "fetch information from default param failed.";
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
if (status != RET_OK) {
MS_LOG(ERROR) << "get data type and shape from param node failed.";
return RET_ERROR; return RET_ERROR;
} }
data_info->data_type_ = data_type;
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
size_t offset = 0;
if (!shape_vector.empty() && data_type == kObjectTypeString) {
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
if (status != RET_OK) {
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
return RET_ERROR;
}
}
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
}
}
QuantParamHolderPtr quant_param_holder = QuantParamHolderPtr quant_param_holder =
prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>(); prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) { if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
data_info->data_type_ == kNumberTypeInt8) {
data_info->enable_huffman_code_ = true; data_info->enable_huffman_code_ = true;
} }
data_info->node_type_ = NodeType_ValueNode; data_info->node_type_ = NodeType_ValueNode;
@ -287,6 +301,10 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
DataInfo *data_info) { DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr); MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto value_node = cnode->input(index)->cast<ValueNodePtr>(); auto value_node = cnode->input(index)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "input node is not value node.";
return RET_ERROR;
}
auto value = value_node->value(); auto value = value_node->value();
int ret = RET_OK; int ret = RET_OK;
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));

View File

@ -18,26 +18,91 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "ops/depend.h"
#include "ops/make_tuple.h" #include "ops/make_tuple.h"
namespace mindspore::opt { namespace mindspore::opt {
namespace { namespace {
constexpr size_t kInputDoubleNum = 2; constexpr size_t kInputDoubleNum = 2;
constexpr size_t kInputTripleNum = 3; constexpr size_t kInputTripleNum = 3;
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) { int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(anf_node != nullptr); MS_ASSERT(func_graph != nullptr && cnode != nullptr);
MS_ASSERT(inputs != nullptr); auto first_input = cnode->input(1);
auto cnode = anf_node->cast<CNodePtr>(); auto second_input = cnode->input(2);
if (cnode == nullptr) { AnfNodePtr must_monad = nullptr;
return; AnfNodePtr not_must_monad = nullptr;
} if (utils::isa<ValueNode>(first_input)) {
for (size_t i = 1; i < cnode->size(); ++i) { auto value_node = first_input->cast<ValueNodePtr>();
if (cnode->input(i)->isa<CNode>()) { MS_ASSERT(value_node->value() != nullptr);
inputs->push_back(cnode->input(i)); if (utils::isa<Monad>(value_node->value())) {
must_monad = first_input;
not_must_monad = second_input;
} }
} }
if (utils::isa<ValueNode>(second_input)) {
auto value_node = second_input->cast<ValueNodePtr>();
MS_ASSERT(value_node->value() != nullptr);
if (utils::isa<Monad>(value_node->value())) {
must_monad = second_input;
not_must_monad = first_input;
}
}
if (must_monad == nullptr) {
return lite::RET_NO_CHANGE;
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) {
manager->Replace(cnode, must_monad);
} else {
manager->Replace(cnode, not_must_monad);
}
return lite::RET_OK;
}
int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr pre_node = cnode->input(1);
AnfNodePtr post_node = cnode->input(2);
if (!pre_node_is_first) {
pre_node = cnode->input(2);
post_node = cnode->input(1);
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_users = manager->node_users()[pre_node];
auto iter =
std::find_if(node_users.begin(), node_users.end(),
[&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; });
if (iter == node_users.end()) {
return lite::RET_NO_CHANGE;
}
auto tr = manager->Transact();
tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>()));
tr.Commit();
auto depend_prim = std::make_shared<ops::Depend>();
auto depend_node = func_graph->NewCNode(depend_prim, {post_node, pre_node});
depend_node->set_fullname_with_scope(cnode->fullname_with_scope());
manager->Replace(cnode, depend_node);
return lite::RET_OK;
}
int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) {
return lite::RET_OK;
}
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
return lite::RET_OK;
}
auto make_tuple_prim = NewValueNode(std::make_shared<ops::MakeTuple>());
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(cnode->input(0), make_tuple_prim);
return lite::RET_OK;
} }
} // namespace } // namespace
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) { if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode."; MS_LOG(DEBUG) << "anf node is node a cnode.";
@ -73,28 +138,11 @@ int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph,
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
auto cnode = anf_node->cast<CNodePtr>(); auto cnode = anf_node->cast<CNodePtr>();
auto inputs = cnode->inputs(); if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) {
std::vector<AnfNodePtr> new_inputs; return lite::RET_OK;
for (size_t i = 1; i < inputs.size(); ++i) {
if (!inputs[i]->isa<CNode>()) {
continue;
} }
if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) { // both of two inputs are not monad, but have dependency.
FetchCNodeFromMakeTuple(inputs[i], &new_inputs); return ProcessInputHaveDependency(func_graph, cnode);
continue;
}
new_inputs.push_back(inputs[i]);
}
for (auto &node : new_inputs) {
func_graph->get_return()->add_input(node);
}
auto value = std::make_shared<UMonad>();
bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value));
if (!replace_succ) {
MS_LOG(ERROR) << "replace redundant op failed.";
return lite::RET_ERROR;
}
return RET_OK;
} }
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {