insert transdata for pynative output

This commit is contained in:
yuchaojie 2022-03-16 16:56:47 +08:00
parent 016ec19a99
commit aae5cb4e31
2 changed files with 33 additions and 2 deletions

View File

@ -22,6 +22,33 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kNchwDimNum = 4;
constexpr size_t kDimC = 1;
bool IsDepthwiseCase(const CNodePtr &node, size_t index, const std::string &format, bool is_tuple) {
if (format != kOpFormat_FRAC_Z) {
return false;
}
abstract::BaseShapePtr base_shape = is_tuple ? common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index)
: common::AnfAlgo::GetOutputDetailShape(node, index);
MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>()) {
auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
auto shape_vec = shape_ptr->shape();
return shape_vec.size() == kNchwDimNum && shape_vec[kDimC] == 1;
}
return false;
}
bool NeedInsertTransDataForOutput(const CNodePtr &node, size_t index, bool is_tuple) {
const std::set<std::string> formats_need_transdata = {kOpFormat_ND_RNN_BIAS, kOpFormat_FRACTAL_ZN_RNN,
kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_ZN_LSTM};
auto format = is_tuple ? AnfAlgo::GetPrevNodeOutputFormat(node, index) : AnfAlgo::GetOutputFormat(node, index);
return formats_need_transdata.count(format) != 0 || IsDepthwiseCase(node, index, format, is_tuple);
}
} // namespace
bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
bool changed = false;
@ -32,8 +59,7 @@ bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
auto inputs_num = common::AnfAlgo::GetInputNum(cnode);
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
for (size_t index = 0; index < inputs_num; index++) {
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
if (format == kOpFormat_ND_RNN_BIAS || format == kOpFormat_FRACTAL_ZN_RNN) {
if (NeedInsertTransDataForOutput(cnode, index, true)) {
auto cur_cnode_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, index, false);
auto trans_node =
AddTransOpNodeToGraph(graph, cur_cnode_with_index.first, kernel_select_, cur_cnode_with_index.second, false);
@ -41,6 +67,10 @@ bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
has_changed = true;
}
}
} else if (!common::AnfAlgo::IsTupleOutput(cnode) && NeedInsertTransDataForOutput(cnode, 0, false)) {
auto trans_node = AddTransOpNodeToGraph(graph, cnode, kernel_select_, 0, false);
has_changed = true;
graph->set_output(trans_node);
}
if (has_changed) {

View File

@ -102,6 +102,7 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
node_address->set_from_persistent_mem(input_tensor->is_parameter());
node_address->SetNodeIndex(input_node, 0);
UpdateRefCount(node_address.get(), true);
}