forked from mindspore-Ecosystem/mindspore
!15725 [lite]fix train bug
From: @xu_anyue Reviewed-by: @HilbertDavid,@jpc_chenjianping Signed-off-by: @jpc_chenjianping
This commit is contained in:
commit
7835f73fea
|
@ -110,6 +110,40 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int FetchFromDefaultParam(const ParameterPtr ¶m_node, DataInfo *data_info) {
|
||||||
|
MS_ASSERT(param_node != nullptr && data_info != nullptr);
|
||||||
|
ShapeVector shape_vector;
|
||||||
|
TypeId data_type;
|
||||||
|
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "get data type and shape from param node failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
data_info->data_type_ = data_type;
|
||||||
|
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
|
||||||
|
size_t offset = 0;
|
||||||
|
if (!shape_vector.empty() && data_type == kObjectTypeString) {
|
||||||
|
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
||||||
|
data_info->shape_ = dims;
|
||||||
|
if (tensor_info != nullptr && tensor_info->Size() != 0) {
|
||||||
|
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
|
||||||
|
data_info->data_.resize(tensor_info->Size() - offset);
|
||||||
|
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
|
||||||
|
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
|
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
|
||||||
bool train_flag, DataInfo *data_info) {
|
bool train_flag, DataInfo *data_info) {
|
||||||
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
|
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
|
||||||
|
@ -230,10 +264,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
||||||
DataInfo *data_info) {
|
DataInfo *data_info) {
|
||||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||||
auto param_node = cnode->input(index)->cast<ParameterPtr>();
|
auto param_node = cnode->input(index)->cast<ParameterPtr>();
|
||||||
|
if (param_node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "input node is not parameter node.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
data_info->format_ = GetFormatByFmk(fmk_type);
|
data_info->format_ = GetFormatByFmk(fmk_type);
|
||||||
if (data_info->format_ < 0) {
|
if (data_info->format_ < 0) {
|
||||||
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
|
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
|
||||||
return lite::RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
|
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
|
||||||
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
|
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
|
||||||
|
@ -245,38 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
||||||
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
|
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
|
||||||
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
|
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
|
||||||
}
|
}
|
||||||
ShapeVector shape_vector;
|
if (FetchFromDefaultParam(param_node, data_info) != RET_OK) {
|
||||||
TypeId data_type;
|
MS_LOG(ERROR) << "fetch information from default param failed.";
|
||||||
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "get data type and shape from param node failed.";
|
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
data_info->data_type_ = data_type;
|
|
||||||
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
|
|
||||||
size_t offset = 0;
|
|
||||||
if (!shape_vector.empty() && data_type == kObjectTypeString) {
|
|
||||||
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
|
||||||
data_info->shape_ = dims;
|
|
||||||
if (tensor_info != nullptr && tensor_info->Size() != 0) {
|
|
||||||
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
|
|
||||||
data_info->data_.resize(tensor_info->Size() - offset);
|
|
||||||
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
|
|
||||||
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
|
|
||||||
MS_LOG(ERROR) << "memcpy_s failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
QuantParamHolderPtr quant_param_holder =
|
QuantParamHolderPtr quant_param_holder =
|
||||||
prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
|
prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
|
||||||
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) {
|
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
|
||||||
|
data_info->data_type_ == kNumberTypeInt8) {
|
||||||
data_info->enable_huffman_code_ = true;
|
data_info->enable_huffman_code_ = true;
|
||||||
}
|
}
|
||||||
data_info->node_type_ = NodeType_ValueNode;
|
data_info->node_type_ = NodeType_ValueNode;
|
||||||
|
@ -287,6 +301,10 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
|
||||||
DataInfo *data_info) {
|
DataInfo *data_info) {
|
||||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||||
auto value_node = cnode->input(index)->cast<ValueNodePtr>();
|
auto value_node = cnode->input(index)->cast<ValueNodePtr>();
|
||||||
|
if (value_node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "input node is not value node.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
auto value = value_node->value();
|
auto value = value_node->value();
|
||||||
int ret = RET_OK;
|
int ret = RET_OK;
|
||||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
|
|
@ -18,26 +18,91 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
#include "ops/depend.h"
|
||||||
#include "ops/make_tuple.h"
|
#include "ops/make_tuple.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kInputDoubleNum = 2;
|
constexpr size_t kInputDoubleNum = 2;
|
||||||
constexpr size_t kInputTripleNum = 3;
|
constexpr size_t kInputTripleNum = 3;
|
||||||
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) {
|
int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
MS_ASSERT(anf_node != nullptr);
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||||
MS_ASSERT(inputs != nullptr);
|
auto first_input = cnode->input(1);
|
||||||
auto cnode = anf_node->cast<CNodePtr>();
|
auto second_input = cnode->input(2);
|
||||||
if (cnode == nullptr) {
|
AnfNodePtr must_monad = nullptr;
|
||||||
return;
|
AnfNodePtr not_must_monad = nullptr;
|
||||||
}
|
if (utils::isa<ValueNode>(first_input)) {
|
||||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
auto value_node = first_input->cast<ValueNodePtr>();
|
||||||
if (cnode->input(i)->isa<CNode>()) {
|
MS_ASSERT(value_node->value() != nullptr);
|
||||||
inputs->push_back(cnode->input(i));
|
if (utils::isa<Monad>(value_node->value())) {
|
||||||
|
must_monad = first_input;
|
||||||
|
not_must_monad = second_input;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (utils::isa<ValueNode>(second_input)) {
|
||||||
|
auto value_node = second_input->cast<ValueNodePtr>();
|
||||||
|
MS_ASSERT(value_node->value() != nullptr);
|
||||||
|
if (utils::isa<Monad>(value_node->value())) {
|
||||||
|
must_monad = second_input;
|
||||||
|
not_must_monad = first_input;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (must_monad == nullptr) {
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto manager = func_graph->manager();
|
||||||
|
MS_ASSERT(manager != nullptr);
|
||||||
|
if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) {
|
||||||
|
manager->Replace(cnode, must_monad);
|
||||||
|
} else {
|
||||||
|
manager->Replace(cnode, not_must_monad);
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) {
|
||||||
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||||
|
AnfNodePtr pre_node = cnode->input(1);
|
||||||
|
AnfNodePtr post_node = cnode->input(2);
|
||||||
|
if (!pre_node_is_first) {
|
||||||
|
pre_node = cnode->input(2);
|
||||||
|
post_node = cnode->input(1);
|
||||||
|
}
|
||||||
|
auto manager = func_graph->manager();
|
||||||
|
MS_ASSERT(manager != nullptr);
|
||||||
|
auto node_users = manager->node_users()[pre_node];
|
||||||
|
auto iter =
|
||||||
|
std::find_if(node_users.begin(), node_users.end(),
|
||||||
|
[&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; });
|
||||||
|
if (iter == node_users.end()) {
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto tr = manager->Transact();
|
||||||
|
tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>()));
|
||||||
|
tr.Commit();
|
||||||
|
auto depend_prim = std::make_shared<ops::Depend>();
|
||||||
|
auto depend_node = func_graph->NewCNode(depend_prim, {post_node, pre_node});
|
||||||
|
depend_node->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||||
|
manager->Replace(cnode, depend_node);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||||
|
if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) {
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
auto make_tuple_prim = NewValueNode(std::make_shared<ops::MakeTuple>());
|
||||||
|
auto manager = func_graph->manager();
|
||||||
|
MS_ASSERT(manager != nullptr);
|
||||||
|
manager->Replace(cnode->input(0), make_tuple_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||||
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||||
|
@ -73,28 +138,11 @@ int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph,
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
auto cnode = anf_node->cast<CNodePtr>();
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
auto inputs = cnode->inputs();
|
if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) {
|
||||||
std::vector<AnfNodePtr> new_inputs;
|
return lite::RET_OK;
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
||||||
if (!inputs[i]->isa<CNode>()) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) {
|
// both of two inputs are not monad, but have dependency.
|
||||||
FetchCNodeFromMakeTuple(inputs[i], &new_inputs);
|
return ProcessInputHaveDependency(func_graph, cnode);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
new_inputs.push_back(inputs[i]);
|
|
||||||
}
|
|
||||||
for (auto &node : new_inputs) {
|
|
||||||
func_graph->get_return()->add_input(node);
|
|
||||||
}
|
|
||||||
auto value = std::make_shared<UMonad>();
|
|
||||||
bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value));
|
|
||||||
if (!replace_succ) {
|
|
||||||
MS_LOG(ERROR) << "replace redundant op failed.";
|
|
||||||
return lite::RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||||
|
|
Loading…
Reference in New Issue