forked from OSSInnovation/mindspore
!1582 add topk and randomchoicewithmask op data type for aicpu
Merge pull request !1582 from yanzhenxiang2020/r03_add_datatype
This commit is contained in:
commit
b70b2da675
|
@ -25,6 +25,7 @@ random_choice_with_mask_op_info = AiCPURegOp("RandomChoiceWithMask") \
|
||||||
.attr("seed", "int") \
|
.attr("seed", "int") \
|
||||||
.attr("seed2", "int") \
|
.attr("seed2", "int") \
|
||||||
.dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \
|
.dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
@op_info_register(random_choice_with_mask_op_info)
|
@op_info_register(random_choice_with_mask_op_info)
|
||||||
|
|
|
@ -24,6 +24,7 @@ top_k_op_info = AiCPURegOp("TopK") \
|
||||||
.output(0, "values", "required") \
|
.output(0, "values", "required") \
|
||||||
.output(1, "indices", "required") \
|
.output(1, "indices", "required") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
@op_info_register(top_k_op_info)
|
@op_info_register(top_k_op_info)
|
||||||
|
|
Loading…
Reference in New Issue