diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc index 3decd4323eb..e57370bf85e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc @@ -22,25 +22,58 @@ namespace mindspore { namespace kernel { +using Tensor = mindspore::tensor::Tensor; +using TensorPtr = mindspore::tensor::TensorPtr; +using AbstractTensor = mindspore::abstract::AbstractTensor; +using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr; using CheckSupportFun = bool (*)(const CNodePtr &cnode); constexpr char kAttrStrides[] = "strides"; +constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask"; static bool CheckStridedSlice(const CNodePtr &cnode) { - // check stride[-1] != 1 TODO + // check stride[-1] != 1 if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) { auto strides = AnfAlgo::GetNodeAttr>(cnode, kAttrStrides); - if (!strides.empty() && strides[strides.size() - 1] == 1) { - return true; + if (!strides.empty() && strides[strides.size() - 1] != 1) { + return false; + } + } + // check reduction on the last dimension + if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) { + auto shrink_axis_mask = AnfAlgo::GetNodeAttr(cnode, kAttrShrinkAxisMask); + AnfNodePtr input = cnode->input(1); + int input_dims = 0; + if (input->isa()) { + ValuePtr input_value = input->cast()->value(); + if (!input_value->isa()) { + MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got " + << input_value->ToString(); + } + input_dims = SizeToInt(input_value->cast()->shape().size()); + } else if (input->isa() || input->isa()) { + AbstractBasePtr input_abstract = input->abstract(); + if (!input_abstract->isa()) { + MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got " + << input_abstract->ToString(); + } + input_dims = SizeToInt(input_abstract->cast()->shape()->shape().size()); + } else { + MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got " + << input->ToString(); + } + int base_number = 2; + if (shrink_axis_mask >= std::pow(base_number, input_dims - 1)) { + return false; } } - // last tensor TODO return true; } bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - static std::map tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}}; + static std::map tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice}, + {kStridedSliceGradOpName, CheckStridedSlice}}; auto cnode_type = AnfAlgo::GetCNodeName(cnode); auto find_iter = tbe_property_checker.find(cnode_type); if (find_iter != tbe_property_checker.end()) { diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc index fdb6d6ae8f9..5ef4991e4c2 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -59,6 +59,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(kScatterNdOpName, {2}); Register(kStridedSliceAssignOpName, {1, 2, 3}); Register(kStridedSliceOpName, {1, 2, 3}); + Register(kStridedSliceGradOpName, {1, 2, 3, 4}); Register(kFlattenGradOpName, {1}); Register(kExpandDimsOpName, {1}); Register(kSplitOpName, {0}); diff --git a/mindspore/ops/_op_impl/aicpu/strided_slice.py b/mindspore/ops/_op_impl/aicpu/strided_slice.py index 0506e4104d6..d9e3124b0ea 100644 --- a/mindspore/ops/_op_impl/aicpu/strided_slice.py +++ b/mindspore/ops/_op_impl/aicpu/strided_slice.py @@ -16,25 +16,27 @@ """StridedSlice op""" from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType -strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \ +strided_slice_op_info = AiCPURegOp("StridedSlice") \ .fusion_type("OPAQUE") \ .input(0, "input", "required") \ - .input(1, "begin", "required") \ - .input(2, "end", "required") \ - .input(3, "stride", "required") \ .output(0, "output", "required") \ + .attr("begin", "listInt") \ + .attr("end", "listInt") \ + .attr("strides", "listInt") \ .attr("begin_mask", "int") \ .attr("end_mask", "int") \ .attr("ellipsis_mask", "int") \ .attr("new_axis_mask", "int") \ .attr("shrink_axis_mask", "int") \ - .dtype_format(DataType.F32_Default, - DataType.I32_Default, - DataType.I32_Default, - DataType.I32_Default, - DataType.F32_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ .get_op_info() + @op_info_register(strided_slice_op_info) def _strided_slice_aicpu(): """StridedSlice AiCPU register""" diff --git a/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py b/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py index b94c5d4c4bc..df9c2da53f4 100644 --- a/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +++ b/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py @@ -16,27 +16,28 @@ """StridedSliceGrad op""" from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType -strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \ +strided_slice_grad_op_info = AiCPURegOp("StridedSliceGrad") \ .fusion_type("OPAQUE") \ .input(0, "dy", "required") \ - .input(1, "shape", "required") \ - .input(2, "begin", "required") \ - .input(3, "end", "required") \ - .input(4, "stride", "required") \ .output(0, "output", "required") \ + .attr("shapex", "listInt") \ + .attr("begin", "listInt") \ + .attr("end", "listInt") \ + .attr("strides", "listInt") \ .attr("begin_mask", "int") \ .attr("end_mask", "int") \ .attr("ellipsis_mask", "int") \ .attr("new_axis_mask", "int") \ .attr("shrink_axis_mask", "int") \ - .dtype_format(DataType.F32_Default, - DataType.I32_Default, - DataType.I32_Default, - DataType.I32_Default, - DataType.I32_Default, - DataType.F32_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ .get_op_info() + @op_info_register(strided_slice_grad_op_info) def _strided_slice_grad_aicpu(): """StridedSliceGrad AiCPU register""" diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 250de74ef6c..16e2b1db101 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -915,13 +915,14 @@ test_case_math_ops = [ 'block': G.MinimumGrad(), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], 'skip': ['backward']}), - ('StridedSlice', { - 'block': P.StridedSlice(), + ('StridedSlice_00', { + 'block': P.StridedSlice(shrink_axis_mask=0), 'desc_const': [(0, 1, 2, 1), (2, 3, 3, 4), - (1, 1, 1, 1)], + (1, 1, 1, 2)], 'desc_inputs': [[2, 3, 3, 5]], - 'desc_bprop': [[2, 2, 1, 3]]}), + 'desc_bprop': [[2, 2, 1, 3]], + 'skip': ['backward']}), ('Slice_1', { 'block': P.Slice(), 'desc_const': [(0, 1, 2, 1),