forked from mindspore-Ecosystem/mindspore
insert transdata for pynative output
This commit is contained in:
parent
016ec19a99
commit
aae5cb4e31
|
@ -22,6 +22,33 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kNchwDimNum = 4;
|
||||||
|
constexpr size_t kDimC = 1;
|
||||||
|
|
||||||
|
bool IsDepthwiseCase(const CNodePtr &node, size_t index, const std::string &format, bool is_tuple) {
|
||||||
|
if (format != kOpFormat_FRAC_Z) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
abstract::BaseShapePtr base_shape = is_tuple ? common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index)
|
||||||
|
: common::AnfAlgo::GetOutputDetailShape(node, index);
|
||||||
|
MS_EXCEPTION_IF_NULL(base_shape);
|
||||||
|
if (base_shape->isa<abstract::Shape>()) {
|
||||||
|
auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
|
||||||
|
auto shape_vec = shape_ptr->shape();
|
||||||
|
return shape_vec.size() == kNchwDimNum && shape_vec[kDimC] == 1;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NeedInsertTransDataForOutput(const CNodePtr &node, size_t index, bool is_tuple) {
|
||||||
|
const std::set<std::string> formats_need_transdata = {kOpFormat_ND_RNN_BIAS, kOpFormat_FRACTAL_ZN_RNN,
|
||||||
|
kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_ZN_LSTM};
|
||||||
|
auto format = is_tuple ? AnfAlgo::GetPrevNodeOutputFormat(node, index) : AnfAlgo::GetOutputFormat(node, index);
|
||||||
|
return formats_need_transdata.count(format) != 0 || IsDepthwiseCase(node, index, format, is_tuple);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
|
bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
@ -32,8 +59,7 @@ bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
|
||||||
auto inputs_num = common::AnfAlgo::GetInputNum(cnode);
|
auto inputs_num = common::AnfAlgo::GetInputNum(cnode);
|
||||||
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
|
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
|
||||||
for (size_t index = 0; index < inputs_num; index++) {
|
for (size_t index = 0; index < inputs_num; index++) {
|
||||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
|
if (NeedInsertTransDataForOutput(cnode, index, true)) {
|
||||||
if (format == kOpFormat_ND_RNN_BIAS || format == kOpFormat_FRACTAL_ZN_RNN) {
|
|
||||||
auto cur_cnode_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, index, false);
|
auto cur_cnode_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, index, false);
|
||||||
auto trans_node =
|
auto trans_node =
|
||||||
AddTransOpNodeToGraph(graph, cur_cnode_with_index.first, kernel_select_, cur_cnode_with_index.second, false);
|
AddTransOpNodeToGraph(graph, cur_cnode_with_index.first, kernel_select_, cur_cnode_with_index.second, false);
|
||||||
|
@ -41,6 +67,10 @@ bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) {
|
||||||
has_changed = true;
|
has_changed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (!common::AnfAlgo::IsTupleOutput(cnode) && NeedInsertTransDataForOutput(cnode, 0, false)) {
|
||||||
|
auto trans_node = AddTransOpNodeToGraph(graph, cnode, kernel_select_, 0, false);
|
||||||
|
has_changed = true;
|
||||||
|
graph->set_output(trans_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_changed) {
|
if (has_changed) {
|
||||||
|
|
|
@ -102,6 +102,7 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
|
||||||
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
|
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
|
||||||
input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
|
input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
|
||||||
node_address->set_from_persistent_mem(input_tensor->is_parameter());
|
node_address->set_from_persistent_mem(input_tensor->is_parameter());
|
||||||
|
node_address->SetNodeIndex(input_node, 0);
|
||||||
UpdateRefCount(node_address.get(), true);
|
UpdateRefCount(node_address.get(), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue