forked from mindspore-Ecosystem/mindspore
!3317 add check for stridedslice when choosing aicpu or aicore
Merge pull request !3317 from zhangbuxue/add_check_for_stridedslice_when_choose_aicpu_or_aicore
commit b83d921735
@@ -22,25 +22,58 @@
 namespace mindspore {
 namespace kernel {
+using Tensor = mindspore::tensor::Tensor;
+using TensorPtr = mindspore::tensor::TensorPtr;
+using AbstractTensor = mindspore::abstract::AbstractTensor;
+using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr;
 using CheckSupportFun = bool (*)(const CNodePtr &cnode);
 
 constexpr char kAttrStrides[] = "strides";
+constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask";
 
 static bool CheckStridedSlice(const CNodePtr &cnode) {
-  // check stride[-1] != 1 TODO
+  // check stride[-1] != 1
   if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) {
     auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides);
-    if (!strides.empty() && strides[strides.size() - 1] == 1) {
-      return true;
+    if (!strides.empty() && strides[strides.size() - 1] != 1) {
+      return false;
     }
   }
-  // last tensor TODO
+  // check reduction on the last dimension
+  if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) {
+    auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int>(cnode, kAttrShrinkAxisMask);
+    AnfNodePtr input = cnode->input(1);
+    int input_dims = 0;
+    if (input->isa<ValueNode>()) {
+      ValuePtr input_value = input->cast<ValueNodePtr>()->value();
+      if (!input_value->isa<Tensor>()) {
+        MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
+                          << input_value->ToString();
+      }
+      input_dims = SizeToInt(input_value->cast<TensorPtr>()->shape().size());
+    } else if (input->isa<CNode>() || input->isa<Parameter>()) {
+      AbstractBasePtr input_abstract = input->abstract();
+      if (!input_abstract->isa<AbstractTensor>()) {
+        MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
+                          << input_abstract->ToString();
+      }
+      input_dims = SizeToInt(input_abstract->cast<AbstractTensorPtr>()->shape()->shape().size());
+    } else {
+      MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got "
+                        << input->ToString();
+    }
+    int base_number = 2;
+    if (shrink_axis_mask >= std::pow<int, int>(base_number, input_dims - 1)) {
+      return false;
+    }
+  }
   return true;
 }
 
 bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  static std::map<std::string, CheckSupportFun> tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}};
+  static std::map<std::string, CheckSupportFun> tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice},
+                                                                        {kStridedSliceGradOpName, CheckStridedSlice}};
   auto cnode_type = AnfAlgo::GetCNodeName(cnode);
   auto find_iter = tbe_property_checker.find(cnode_type);
   if (find_iter != tbe_property_checker.end()) {
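For reference, a rough Python restatement of the new CheckStridedSlice logic (illustrative only; the C++ above is authoritative, and the helper name here is ours). strides, shrink_axis_mask, and input_dims stand for the strides attribute, the shrink_axis_mask attribute, and the rank of the first input that the C++ code extracts from the node:

def tbe_supports_strided_slice(strides, shrink_axis_mask, input_dims):
    """Return False when the node should fall back to the AiCPU kernel."""
    # A non-unit stride on the last dimension is not supported by the TBE kernel.
    if strides and strides[-1] != 1:
        return False
    # Neither is shrinking (reducing) the last dimension: a mask bit at
    # position input_dims - 1 or above makes the value >= 2 ** (input_dims - 1).
    if shrink_axis_mask >= 2 ** (input_dims - 1):
        return False
    return True

For example, tbe_supports_strided_slice([1, 1, 1, 2], 0, 4) is False (AiCPU fallback), while tbe_supports_strided_slice([1, 1, 1, 1], 0, 4) is True (AiCore/TBE).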
@@ -59,6 +59,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(kScatterNdOpName, {2});
   Register(kStridedSliceAssignOpName, {1, 2, 3});
   Register(kStridedSliceOpName, {1, 2, 3});
+  Register(kStridedSliceGradOpName, {1, 2, 3, 4});
   Register(kFlattenGradOpName, {1});
   Register(kExpandDimsOpName, {1});
   Register(kSplitOpName, {0});
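The new registry entry means the four constant inputs of StridedSliceGrad are folded into node attributes, as already happens for the three slice parameters of StridedSlice, so the AiCPU op infos below can declare them as "listInt" attrs. An illustrative sketch of the mapping; the attribute names are taken from the registrations in this commit, and the index-to-name pairing is an assumption based on the op signatures:

# Positions of const inputs that are converted into attrs (index 0 is the
# tensor input: input_x for StridedSlice, dy for StridedSliceGrad).
CONST_INPUT_TO_ATTR = {
    "StridedSlice": {1: "begin", 2: "end", 3: "strides"},
    "StridedSliceGrad": {1: "shapex", 2: "begin", 3: "end", 4: "strides"},
}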
@@ -16,25 +16,27 @@
 """StridedSlice op"""
 from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
 
-strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \
+strided_slice_op_info = AiCPURegOp("StridedSlice") \
     .fusion_type("OPAQUE") \
     .input(0, "input", "required") \
-    .input(1, "begin", "required") \
-    .input(2, "end", "required") \
-    .input(3, "stride", "required") \
     .output(0, "output", "required") \
+    .attr("begin", "listInt") \
+    .attr("end", "listInt") \
+    .attr("strides", "listInt") \
     .attr("begin_mask", "int") \
     .attr("end_mask", "int") \
     .attr("ellipsis_mask", "int") \
     .attr("new_axis_mask", "int") \
     .attr("shrink_axis_mask", "int") \
-    .dtype_format(DataType.F32_Default,
-                  DataType.I32_Default,
-                  DataType.I32_Default,
-                  DataType.I32_Default,
-                  DataType.F32_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
     .get_op_info()
 
 
 @op_info_register(strided_slice_op_info)
 def _strided_slice_aicpu():
     """StridedSlice AiCPU register"""
@@ -16,27 +16,28 @@
 """StridedSliceGrad op"""
 from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
 
-strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \
+strided_slice_grad_op_info = AiCPURegOp("StridedSliceGrad") \
     .fusion_type("OPAQUE") \
     .input(0, "dy", "required") \
-    .input(1, "shape", "required") \
-    .input(2, "begin", "required") \
-    .input(3, "end", "required") \
-    .input(4, "stride", "required") \
     .output(0, "output", "required") \
+    .attr("shapex", "listInt") \
+    .attr("begin", "listInt") \
+    .attr("end", "listInt") \
+    .attr("strides", "listInt") \
     .attr("begin_mask", "int") \
     .attr("end_mask", "int") \
     .attr("ellipsis_mask", "int") \
     .attr("new_axis_mask", "int") \
     .attr("shrink_axis_mask", "int") \
-    .dtype_format(DataType.F32_Default,
-                  DataType.I32_Default,
-                  DataType.I32_Default,
-                  DataType.I32_Default,
-                  DataType.I32_Default,
-                  DataType.F32_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()
 
 
 @op_info_register(strided_slice_grad_op_info)
 def _strided_slice_grad_aicpu():
     """StridedSliceGrad AiCPU register"""
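With both AiCPU op infos now registered under the plain op names (StridedSlice and StridedSliceGrad), a slice that the TBE checker rejects can still execute on Ascend through the AiCPU kernel. A minimal usage sketch, assuming an Ascend environment where these AiCPU kernels are deployed:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")

x = Tensor(np.ones((4, 6, 8), np.float32))
slice_op = P.StridedSlice()

# Last-axis stride is 1: CheckStridedSlice passes, so the TBE/AiCore kernel is selected.
on_aicore = slice_op(x, (0, 0, 0), (4, 6, 8), (2, 3, 1))
# Last-axis stride is 2: the check fails, so the AiCPU kernel registered above is selected.
on_aicpu = slice_op(x, (0, 0, 0), (4, 6, 8), (1, 1, 2))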
@@ -915,13 +915,14 @@ test_case_math_ops = [
         'block': G.MinimumGrad(),
         'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]],
         'skip': ['backward']}),
-    ('StridedSlice', {
-        'block': P.StridedSlice(),
+    ('StridedSlice_00', {
+        'block': P.StridedSlice(shrink_axis_mask=0),
         'desc_const': [(0, 1, 2, 1),
                        (2, 3, 3, 4),
-                       (1, 1, 1, 1)],
+                       (1, 1, 1, 2)],
         'desc_inputs': [[2, 3, 3, 5]],
-        'desc_bprop': [[2, 2, 1, 3]]}),
+        'desc_bprop': [[2, 2, 1, 3]],
+        'skip': ['backward']}),
     ('Slice_1', {
         'block': P.Slice(),
         'desc_const': [(0, 1, 2, 1),
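The updated test entry uses a last-axis stride of 2, so it exercises the AiCPU fallback introduced by this commit (its backward pass is skipped). Assuming StridedSlice with all masks set to 0 matches NumPy basic slicing, the expected forward shape for these constants can be cross-checked like this:

import numpy as np

x = np.zeros((2, 3, 3, 5), dtype=np.float32)
# begin (0, 1, 2, 1), end (2, 3, 3, 4), strides (1, 1, 1, 2)
out = x[0:2:1, 1:3:1, 2:3:1, 1:4:2]
print(out.shape)  # (2, 2, 1, 2)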