!15309 modify topk adapter for r1.2

From: @changzherui
Reviewed-by: @oacjiewen,@kingxian
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2021-04-17 12:55:27 +08:00 committed by Gitee
commit b869325005
2 changed files with 21 additions and 0 deletions

View File

@ -1445,6 +1445,20 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
tuple_out_handle_cache_[node.get()] = tuple_items;
}
void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
auto value_ptr = node->input(2)->cast<ValueNodePtr>();
std::ostringstream ss;
ss << "op" << value_ptr.get();
op_draw_name_[value_ptr.get()] = ss.str();
compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
auto int64_value = value_ptr->value()->cast<Int64ImmPtr>()->value();
OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
auto op = adpt->generate(value_ptr);
adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));
op_cache_[value_ptr.get()] = op;
}
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
const int TUPLE_GET_ITEM_INDEX = 2;
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
@ -1806,6 +1820,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return false;
}
// Convert TopK second input from int64 to int32.
if (name == prim::kPrimTopK->name()) {
ConvertTopK(node);
return true;
}
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
if (name == prim::kPrimMakeTuple->name()) {
ConvertMakeTuple(node);

View File

@ -159,6 +159,7 @@ class DfGraphConvertor {
void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node);
void ConvertControlDependNode(const CNodePtr node);
void ConvertMakeTuple(const CNodePtr node);
void ConvertTopK(const CNodePtr node);
bool CheckCNode(const std::string &name, const CNodePtr node);
void TraceOutput(AnfNodePtr node);
void TraceOutputFromParameter(const AnfNodePtr &anf_out);