!3317 add check for stridedslice when choose aicpu or aicore

Merge pull request !3317 from zhangbuxue/add_check_for_stridedslice_when_choose_aicpu_or_aicore
This commit is contained in:
mindspore-ci-bot 2020-09-15 10:37:09 +08:00 committed by Gitee
commit b83d921735
5 changed files with 67 additions and 29 deletions

View File

@ -22,25 +22,58 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
using Tensor = mindspore::tensor::Tensor;
using TensorPtr = mindspore::tensor::TensorPtr;
using AbstractTensor = mindspore::abstract::AbstractTensor;
using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr;
using CheckSupportFun = bool (*)(const CNodePtr &cnode); using CheckSupportFun = bool (*)(const CNodePtr &cnode);
constexpr char kAttrStrides[] = "strides"; constexpr char kAttrStrides[] = "strides";
constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask";
static bool CheckStridedSlice(const CNodePtr &cnode) { static bool CheckStridedSlice(const CNodePtr &cnode) {
// check stride[-1] != 1 TODO // check stride[-1] != 1
if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) { if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) {
auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides); auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides);
if (!strides.empty() && strides[strides.size() - 1] == 1) { if (!strides.empty() && strides[strides.size() - 1] != 1) {
return true; return false;
}
}
// check reduction on the last dimension
if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) {
auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int>(cnode, kAttrShrinkAxisMask);
AnfNodePtr input = cnode->input(1);
int input_dims = 0;
if (input->isa<ValueNode>()) {
ValuePtr input_value = input->cast<ValueNodePtr>()->value();
if (!input_value->isa<Tensor>()) {
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
<< input_value->ToString();
}
input_dims = SizeToInt(input_value->cast<TensorPtr>()->shape().size());
} else if (input->isa<CNode>() || input->isa<Parameter>()) {
AbstractBasePtr input_abstract = input->abstract();
if (!input_abstract->isa<AbstractTensor>()) {
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
<< input_abstract->ToString();
}
input_dims = SizeToInt(input_abstract->cast<AbstractTensorPtr>()->shape()->shape().size());
} else {
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got "
<< input->ToString();
}
int base_number = 2;
if (shrink_axis_mask >= std::pow<int, int>(base_number, input_dims - 1)) {
return false;
} }
} }
// last tensor TODO
return true; return true;
} }
bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) { bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
static std::map<std::string, CheckSupportFun> tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}}; static std::map<std::string, CheckSupportFun> tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice},
{kStridedSliceGradOpName, CheckStridedSlice}};
auto cnode_type = AnfAlgo::GetCNodeName(cnode); auto cnode_type = AnfAlgo::GetCNodeName(cnode);
auto find_iter = tbe_property_checker.find(cnode_type); auto find_iter = tbe_property_checker.find(cnode_type);
if (find_iter != tbe_property_checker.end()) { if (find_iter != tbe_property_checker.end()) {

View File

@ -59,6 +59,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(kScatterNdOpName, {2}); Register(kScatterNdOpName, {2});
Register(kStridedSliceAssignOpName, {1, 2, 3}); Register(kStridedSliceAssignOpName, {1, 2, 3});
Register(kStridedSliceOpName, {1, 2, 3}); Register(kStridedSliceOpName, {1, 2, 3});
Register(kStridedSliceGradOpName, {1, 2, 3, 4});
Register(kFlattenGradOpName, {1}); Register(kFlattenGradOpName, {1});
Register(kExpandDimsOpName, {1}); Register(kExpandDimsOpName, {1});
Register(kSplitOpName, {0}); Register(kSplitOpName, {0});

View File

@ -16,25 +16,27 @@
"""StridedSlice op""" """StridedSlice op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \ strided_slice_op_info = AiCPURegOp("StridedSlice") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.input(0, "input", "required") \ .input(0, "input", "required") \
.input(1, "begin", "required") \
.input(2, "end", "required") \
.input(3, "stride", "required") \
.output(0, "output", "required") \ .output(0, "output", "required") \
.attr("begin", "listInt") \
.attr("end", "listInt") \
.attr("strides", "listInt") \
.attr("begin_mask", "int") \ .attr("begin_mask", "int") \
.attr("end_mask", "int") \ .attr("end_mask", "int") \
.attr("ellipsis_mask", "int") \ .attr("ellipsis_mask", "int") \
.attr("new_axis_mask", "int") \ .attr("new_axis_mask", "int") \
.attr("shrink_axis_mask", "int") \ .attr("shrink_axis_mask", "int") \
.dtype_format(DataType.F32_Default, .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
DataType.I32_Default, .dtype_format(DataType.I8_Default, DataType.I8_Default) \
DataType.I32_Default, .dtype_format(DataType.U8_Default, DataType.U8_Default) \
DataType.I32_Default, .dtype_format(DataType.I32_Default, DataType.I32_Default) \
DataType.F32_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info() .get_op_info()
@op_info_register(strided_slice_op_info) @op_info_register(strided_slice_op_info)
def _strided_slice_aicpu(): def _strided_slice_aicpu():
"""StridedSlice AiCPU register""" """StridedSlice AiCPU register"""

View File

@ -16,27 +16,28 @@
"""StridedSliceGrad op""" """StridedSliceGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \ strided_slice_grad_op_info = AiCPURegOp("StridedSliceGrad") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.input(0, "dy", "required") \ .input(0, "dy", "required") \
.input(1, "shape", "required") \
.input(2, "begin", "required") \
.input(3, "end", "required") \
.input(4, "stride", "required") \
.output(0, "output", "required") \ .output(0, "output", "required") \
.attr("shapex", "listInt") \
.attr("begin", "listInt") \
.attr("end", "listInt") \
.attr("strides", "listInt") \
.attr("begin_mask", "int") \ .attr("begin_mask", "int") \
.attr("end_mask", "int") \ .attr("end_mask", "int") \
.attr("ellipsis_mask", "int") \ .attr("ellipsis_mask", "int") \
.attr("new_axis_mask", "int") \ .attr("new_axis_mask", "int") \
.attr("shrink_axis_mask", "int") \ .attr("shrink_axis_mask", "int") \
.dtype_format(DataType.F32_Default, .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
DataType.I32_Default, .dtype_format(DataType.I8_Default, DataType.I8_Default) \
DataType.I32_Default, .dtype_format(DataType.U8_Default, DataType.U8_Default) \
DataType.I32_Default, .dtype_format(DataType.I32_Default, DataType.I32_Default) \
DataType.I32_Default, .dtype_format(DataType.F16_Default, DataType.F16_Default) \
DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info() .get_op_info()
@op_info_register(strided_slice_grad_op_info) @op_info_register(strided_slice_grad_op_info)
def _strided_slice_grad_aicpu(): def _strided_slice_grad_aicpu():
"""StridedSliceGrad AiCPU register""" """StridedSliceGrad AiCPU register"""

View File

@ -915,13 +915,14 @@ test_case_math_ops = [
'block': G.MinimumGrad(), 'block': G.MinimumGrad(),
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]],
'skip': ['backward']}), 'skip': ['backward']}),
('StridedSlice', { ('StridedSlice_00', {
'block': P.StridedSlice(), 'block': P.StridedSlice(shrink_axis_mask=0),
'desc_const': [(0, 1, 2, 1), 'desc_const': [(0, 1, 2, 1),
(2, 3, 3, 4), (2, 3, 3, 4),
(1, 1, 1, 1)], (1, 1, 1, 2)],
'desc_inputs': [[2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5]],
'desc_bprop': [[2, 2, 1, 3]]}), 'desc_bprop': [[2, 2, 1, 3]],
'skip': ['backward']}),
('Slice_1', { ('Slice_1', {
'block': P.Slice(), 'block': P.Slice(),
'desc_const': [(0, 1, 2, 1), 'desc_const': [(0, 1, 2, 1),