forked from mindspore-Ecosystem/mindspore
!3317 add check for stridedslice when choose aicpu or aicore
Merge pull request !3317 from zhangbuxue/add_check_for_stridedslice_when_choose_aicpu_or_aicore
This commit is contained in:
commit
b83d921735
|
@ -22,25 +22,58 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using Tensor = mindspore::tensor::Tensor;
|
||||
using TensorPtr = mindspore::tensor::TensorPtr;
|
||||
using AbstractTensor = mindspore::abstract::AbstractTensor;
|
||||
using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr;
|
||||
using CheckSupportFun = bool (*)(const CNodePtr &cnode);
|
||||
|
||||
constexpr char kAttrStrides[] = "strides";
|
||||
constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask";
|
||||
|
||||
static bool CheckStridedSlice(const CNodePtr &cnode) {
|
||||
// check stride[-1] != 1 TODO
|
||||
// check stride[-1] != 1
|
||||
if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) {
|
||||
auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides);
|
||||
if (!strides.empty() && strides[strides.size() - 1] == 1) {
|
||||
return true;
|
||||
if (!strides.empty() && strides[strides.size() - 1] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// check reduction on the last dimension
|
||||
if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) {
|
||||
auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int>(cnode, kAttrShrinkAxisMask);
|
||||
AnfNodePtr input = cnode->input(1);
|
||||
int input_dims = 0;
|
||||
if (input->isa<ValueNode>()) {
|
||||
ValuePtr input_value = input->cast<ValueNodePtr>()->value();
|
||||
if (!input_value->isa<Tensor>()) {
|
||||
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
|
||||
<< input_value->ToString();
|
||||
}
|
||||
input_dims = SizeToInt(input_value->cast<TensorPtr>()->shape().size());
|
||||
} else if (input->isa<CNode>() || input->isa<Parameter>()) {
|
||||
AbstractBasePtr input_abstract = input->abstract();
|
||||
if (!input_abstract->isa<AbstractTensor>()) {
|
||||
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
|
||||
<< input_abstract->ToString();
|
||||
}
|
||||
input_dims = SizeToInt(input_abstract->cast<AbstractTensorPtr>()->shape()->shape().size());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got "
|
||||
<< input->ToString();
|
||||
}
|
||||
int base_number = 2;
|
||||
if (shrink_axis_mask >= std::pow<int, int>(base_number, input_dims - 1)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// last tensor TODO
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
static std::map<std::string, CheckSupportFun> tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}};
|
||||
static std::map<std::string, CheckSupportFun> tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice},
|
||||
{kStridedSliceGradOpName, CheckStridedSlice}};
|
||||
auto cnode_type = AnfAlgo::GetCNodeName(cnode);
|
||||
auto find_iter = tbe_property_checker.find(cnode_type);
|
||||
if (find_iter != tbe_property_checker.end()) {
|
||||
|
|
|
@ -59,6 +59,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
|||
Register(kScatterNdOpName, {2});
|
||||
Register(kStridedSliceAssignOpName, {1, 2, 3});
|
||||
Register(kStridedSliceOpName, {1, 2, 3});
|
||||
Register(kStridedSliceGradOpName, {1, 2, 3, 4});
|
||||
Register(kFlattenGradOpName, {1});
|
||||
Register(kExpandDimsOpName, {1});
|
||||
Register(kSplitOpName, {0});
|
||||
|
|
|
@ -16,25 +16,27 @@
|
|||
"""StridedSlice op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \
|
||||
strided_slice_op_info = AiCPURegOp("StridedSlice") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "input", "required") \
|
||||
.input(1, "begin", "required") \
|
||||
.input(2, "end", "required") \
|
||||
.input(3, "stride", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.attr("begin", "listInt") \
|
||||
.attr("end", "listInt") \
|
||||
.attr("strides", "listInt") \
|
||||
.attr("begin_mask", "int") \
|
||||
.attr("end_mask", "int") \
|
||||
.attr("ellipsis_mask", "int") \
|
||||
.attr("new_axis_mask", "int") \
|
||||
.attr("shrink_axis_mask", "int") \
|
||||
.dtype_format(DataType.F32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(strided_slice_op_info)
|
||||
def _strided_slice_aicpu():
|
||||
"""StridedSlice AiCPU register"""
|
||||
|
|
|
@ -16,27 +16,28 @@
|
|||
"""StridedSliceGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \
|
||||
strided_slice_grad_op_info = AiCPURegOp("StridedSliceGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "dy", "required") \
|
||||
.input(1, "shape", "required") \
|
||||
.input(2, "begin", "required") \
|
||||
.input(3, "end", "required") \
|
||||
.input(4, "stride", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.attr("shapex", "listInt") \
|
||||
.attr("begin", "listInt") \
|
||||
.attr("end", "listInt") \
|
||||
.attr("strides", "listInt") \
|
||||
.attr("begin_mask", "int") \
|
||||
.attr("end_mask", "int") \
|
||||
.attr("ellipsis_mask", "int") \
|
||||
.attr("new_axis_mask", "int") \
|
||||
.attr("shrink_axis_mask", "int") \
|
||||
.dtype_format(DataType.F32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.I32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(strided_slice_grad_op_info)
|
||||
def _strided_slice_grad_aicpu():
|
||||
"""StridedSliceGrad AiCPU register"""
|
||||
|
|
|
@ -915,13 +915,14 @@ test_case_math_ops = [
|
|||
'block': G.MinimumGrad(),
|
||||
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]],
|
||||
'skip': ['backward']}),
|
||||
('StridedSlice', {
|
||||
'block': P.StridedSlice(),
|
||||
('StridedSlice_00', {
|
||||
'block': P.StridedSlice(shrink_axis_mask=0),
|
||||
'desc_const': [(0, 1, 2, 1),
|
||||
(2, 3, 3, 4),
|
||||
(1, 1, 1, 1)],
|
||||
(1, 1, 1, 2)],
|
||||
'desc_inputs': [[2, 3, 3, 5]],
|
||||
'desc_bprop': [[2, 2, 1, 3]]}),
|
||||
'desc_bprop': [[2, 2, 1, 3]],
|
||||
'skip': ['backward']}),
|
||||
('Slice_1', {
|
||||
'block': P.Slice(),
|
||||
'desc_const': [(0, 1, 2, 1),
|
||||
|
|
Loading…
Reference in New Issue