!15309 modify topk adapter for r1.2
From: @changzherui Reviewed-by: @oacjiewen,@kingxian Signed-off-by: @kingxian
This commit is contained in:
commit
b869325005
|
@ -1445,6 +1445,20 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
|
|||
tuple_out_handle_cache_[node.get()] = tuple_items;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
|
||||
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
|
||||
auto value_ptr = node->input(2)->cast<ValueNodePtr>();
|
||||
std::ostringstream ss;
|
||||
ss << "op" << value_ptr.get();
|
||||
op_draw_name_[value_ptr.get()] = ss.str();
|
||||
compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
|
||||
auto int64_value = value_ptr->value()->cast<Int64ImmPtr>()->value();
|
||||
OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
|
||||
auto op = adpt->generate(value_ptr);
|
||||
adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));
|
||||
op_cache_[value_ptr.get()] = op;
|
||||
}
|
||||
|
||||
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
|
||||
const int TUPLE_GET_ITEM_INDEX = 2;
|
||||
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
|
||||
|
@ -1806,6 +1820,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
|||
return false;
|
||||
}
|
||||
|
||||
// Convert TopK second input from int64 to int32.
|
||||
if (name == prim::kPrimTopK->name()) {
|
||||
ConvertTopK(node);
|
||||
return true;
|
||||
}
|
||||
|
||||
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
|
||||
if (name == prim::kPrimMakeTuple->name()) {
|
||||
ConvertMakeTuple(node);
|
||||
|
|
|
@ -159,6 +159,7 @@ class DfGraphConvertor {
|
|||
void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node);
|
||||
void ConvertControlDependNode(const CNodePtr node);
|
||||
void ConvertMakeTuple(const CNodePtr node);
|
||||
void ConvertTopK(const CNodePtr node);
|
||||
bool CheckCNode(const std::string &name, const CNodePtr node);
|
||||
void TraceOutput(AnfNodePtr node);
|
||||
void TraceOutputFromParameter(const AnfNodePtr &anf_out);
|
||||
|
|
Loading…
Reference in New Issue