forked from mindspore-Ecosystem/mindspore
!34618 StridedSlice dynamic shape support mask & ScatterNdUpdate indices support int64 & Concat Gather support feed mode
Merge pull request !34618 from huoxinyou/0519TensorSLice
This commit is contained in:
commit
d4e6a0877d
|
@ -29,10 +29,16 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *input_size_list, size_t *size_i) {
|
||||
if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == -2) {
|
||||
size_t kMaxShapeIdx = 1;
|
||||
int64_t kDynShapeValue = -2;
|
||||
if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == kDynShapeValue) {
|
||||
auto input_max_shape = input_json[kJRange];
|
||||
for (auto &max_shape : input_max_shape) {
|
||||
(*size_i) = SizetMulWithOverflowCheck((*size_i), LongToSize(max_shape[1]));
|
||||
if (max_shape[kMaxShapeIdx] < 0) {
|
||||
(*size_i) = SizetMulWithOverflowCheck((*size_i), 0);
|
||||
} else {
|
||||
(*size_i) = SizetMulWithOverflowCheck((*size_i), LongToSize(max_shape[kMaxShapeIdx]));
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape.";
|
||||
} else {
|
||||
|
@ -42,7 +48,6 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *inp
|
|||
if (j >= input_max_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
|
||||
}
|
||||
size_t kMaxShapeIdx = 1;
|
||||
if (input_max_shape[j][kMaxShapeIdx] == -1) {
|
||||
MS_LOG(INFO) << "Change -1 Shape to 1";
|
||||
(*size_i) = SizetMulWithOverflowCheck((*size_i), 1);
|
||||
|
@ -82,10 +87,16 @@ void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *inp
|
|||
}
|
||||
|
||||
void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *output_size_list, size_t *size_i) {
|
||||
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) {
|
||||
size_t kMaxShapeIdx = 1;
|
||||
int64_t kDynShapeValue = -2;
|
||||
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == kDynShapeValue) {
|
||||
auto output_max_shape = output_json[kJRange];
|
||||
for (auto &max_shape : output_max_shape) {
|
||||
(*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[1]));
|
||||
if (max_shape[kMaxShapeIdx] < 0) {
|
||||
(*size_i) = SizetMulWithOverflowCheck((*size_i), 0);
|
||||
} else {
|
||||
(*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[kMaxShapeIdx]));
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape.";
|
||||
} else {
|
||||
|
@ -95,7 +106,6 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
|
|||
if (j >= output_max_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
|
||||
}
|
||||
size_t kMaxShapeIdx = 1;
|
||||
if (output_max_shape[j][kMaxShapeIdx] == -1) {
|
||||
MS_LOG(INFO) << "Change -1 Shape to 1";
|
||||
(*size_i) = SizetMulWithOverflowCheck((*size_i), 1);
|
||||
|
|
|
@ -70,11 +70,17 @@ abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
if (x_shape_ptr->IsDynamic()) {
|
||||
auto element0_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
|
||||
auto element0_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
|
||||
if (element0_max_shape.empty() || element0_min_shape.empty()) {
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
auto ret_max_shape = element0_max_shape;
|
||||
auto ret_min_shape = element0_min_shape;
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMaxShape];
|
||||
auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMinShape];
|
||||
if (elementi_max_shape.empty() || elementi_min_shape.empty()) {
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
ret_max_shape[axis] += elementi_max_shape[axis];
|
||||
ret_min_shape[axis] += elementi_min_shape[axis];
|
||||
}
|
||||
|
|
|
@ -49,18 +49,13 @@ abstract::ShapePtr DynamicBroadcastToInferShape(const PrimitivePtr &primitive,
|
|||
std::vector<int64_t> real_shape;
|
||||
std::vector<int64_t> max_shape;
|
||||
std::vector<int64_t> min_shape;
|
||||
if (y_shape->IsDynamic()) {
|
||||
auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value();
|
||||
auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value();
|
||||
if (y_shape->IsDynamic() || !min_value || !max_value) {
|
||||
// max shape unknown
|
||||
output_shape.push_back(-2);
|
||||
} else {
|
||||
auto out_dims = LongToSize(y_shape->shape()[0]);
|
||||
auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value();
|
||||
auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value();
|
||||
if (!min_value || !max_value) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'BroadcastTo', inputs['shape'] min or max value can not be empty. But got min: " << min_value
|
||||
<< "max: " << max_value << ".";
|
||||
}
|
||||
min_shape = GetValue<std::vector<int64_t>>(min_value);
|
||||
max_shape = GetValue<std::vector<int64_t>>(max_value);
|
||||
if (min_shape.size() != out_dims || max_shape.size() != out_dims) {
|
||||
|
|
|
@ -34,8 +34,12 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
abstract::AbstractTensorPtr params =
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
|
||||
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
|
||||
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
|
||||
bool ind_has_m_shape = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
|
||||
bool param_has_m_shape = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
|
||||
bool ind_dyn =
|
||||
std::any_of(indices->shape()->shape().begin(), indices->shape()->shape().end(), [](int64_t s) { return s < 0; });
|
||||
bool param_dyn =
|
||||
std::any_of(params->shape()->shape().begin(), params->shape()->shape().end(), [](int64_t s) { return s < 0; });
|
||||
int64_t axis_val = 0;
|
||||
// 3rd input is a Tensor when Gather is a dynamic shape operator
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
|
@ -59,10 +63,10 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
auto params_rank = static_cast<int64_t>(params_shp.size());
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_rank}, op_name);
|
||||
// either inputs or both can be dynamic and computation requires min/max shapes for both
|
||||
ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
|
||||
ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
|
||||
ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
|
||||
ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
|
||||
ShapeVector param_shp_min = (param_has_m_shape) ? params->shape()->min_shape() : params->shape()->shape();
|
||||
ShapeVector param_shp_max = (param_has_m_shape) ? params->shape()->max_shape() : params->shape()->shape();
|
||||
ShapeVector indices_shp_min = (ind_has_m_shape) ? indices->shape()->min_shape() : indices->shape()->shape();
|
||||
ShapeVector indices_shp_max = (ind_has_m_shape) ? indices->shape()->max_shape() : indices->shape()->shape();
|
||||
// check axis_val within interval: [0, params_rank)
|
||||
if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) {
|
||||
MS_LOG(EXCEPTION) << "For 'Gather', axis value must be within range [" << -params_rank << ", " << params_rank
|
||||
|
@ -79,7 +83,7 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
return out_vec;
|
||||
};
|
||||
ShapeVector out_shape = calc_shape(indices_shp, params_shp);
|
||||
if (ind_dyn || param_dyn) {
|
||||
if ((ind_dyn || param_dyn) && ind_dyn == ind_has_m_shape && param_dyn == param_has_m_shape) {
|
||||
ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
|
||||
ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
|
||||
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
|
||||
|
|
|
@ -83,7 +83,7 @@ TypePtr ScatterNdArithmeticInferType(const PrimitivePtr &primitive, const std::v
|
|||
auto input_x_dtype = input_args[kInputIndex0]->BuildType();
|
||||
auto indices_dtype = input_args[kInputIndex1]->BuildType();
|
||||
auto updates_dtype = input_args[kInputIndex2]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_dtype, {kInt32}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_dtype, {kInt32, kInt64}, prim_name);
|
||||
std::map<std::string, TypePtr> type_dict = {{"input_x type", input_x_dtype}, {"updates type", updates_dtype}};
|
||||
// Only ScatterNdUpdate supports boolean type
|
||||
if (prim_name == prim::kPrimScatterNdUpdate->name()) {
|
||||
|
|
|
@ -255,9 +255,8 @@ std::vector<int64_t> ComputeInferShape(const PrimitivePtr &primitive, const std:
|
|||
return infer_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
|
||||
const std::vector<int64_t> &end_v, const std::vector<int64_t> &strides_v,
|
||||
const std::vector<int64_t> &x_shape, const size_t slice_len) {
|
||||
ShapeMap DynamicComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
|
||||
const size_t slice_len, const std::vector<int64_t> &max_shape) {
|
||||
// currently not support mask
|
||||
std::vector<int64_t> begin_pos;
|
||||
std::vector<int64_t> end_pos;
|
||||
|
@ -271,7 +270,10 @@ std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, con
|
|||
int64_t start;
|
||||
int64_t finish;
|
||||
int64_t strides;
|
||||
ShapeMap shape_map;
|
||||
std::vector<int64_t> infer_shape;
|
||||
std::vector<int64_t> infer_min_shape;
|
||||
std::vector<int64_t> infer_max_shape;
|
||||
size_t x_rank = x_shape.size();
|
||||
while (i < x_rank || j < slice_len) {
|
||||
int64_t slicing_length = -1;
|
||||
|
@ -279,6 +281,20 @@ std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, con
|
|||
if (x_dim_size == 1) {
|
||||
slicing_length = 1;
|
||||
}
|
||||
if (j < slice_len) {
|
||||
if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
|
||||
infer_shape.push_back(1);
|
||||
infer_min_shape.push_back(1);
|
||||
infer_max_shape.push_back(1);
|
||||
j += 1;
|
||||
continue;
|
||||
}
|
||||
if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
|
||||
j += 1;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (j >= slice_len && x_dim_size > 0) {
|
||||
start = 0;
|
||||
finish = x_shape[i];
|
||||
|
@ -288,10 +304,17 @@ std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, con
|
|||
}
|
||||
}
|
||||
infer_shape.push_back(slicing_length);
|
||||
if (max_shape.size() != 0) {
|
||||
infer_min_shape.push_back(1);
|
||||
infer_max_shape.push_back(max_shape[i]);
|
||||
}
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
return infer_shape;
|
||||
shape_map[kShape] = infer_shape;
|
||||
shape_map[kMinShape] = infer_min_shape;
|
||||
shape_map[kMaxShape] = infer_max_shape;
|
||||
return shape_map;
|
||||
}
|
||||
|
||||
bool CheckAndGetDynamicSlice(const AbstractBasePtr &input_arg, const std::string &arg_name, ShapeVector *slice_value,
|
||||
|
@ -368,29 +391,17 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
if (!slice_dynamic) {
|
||||
ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, strides_v, x_shape);
|
||||
bool has_zero_shape = std::any_of(ret_in_shape.begin(), ret_in_shape.end(), [](int64_t i) { return i == 0; });
|
||||
if (has_zero_shape) {
|
||||
MS_LOG(EXCEPTION) << "'StridedSlice' doesn't support zero shape, but got out shape: " << ret_in_shape << ".";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(ret_in_shape);
|
||||
}
|
||||
ret_in_shape = DynamicComputeInferShape(primitive, begin_v, end_v, strides_v, x_shape, begin_len);
|
||||
auto ret_shape_map = DynamicComputeInferShape(primitive, x_shape, begin_len, max_shape);
|
||||
ret_in_shape = ret_shape_map[kShape];
|
||||
auto ret_min_shape = ret_shape_map[kMinShape];
|
||||
auto ret_max_shape = ret_shape_map[kMaxShape];
|
||||
|
||||
if (x_is_dyn && (max_shape.empty() || min_shape.empty())) {
|
||||
return std::make_shared<abstract::Shape>(ret_in_shape);
|
||||
}
|
||||
|
||||
ShapeVector ret_min_shape(x_shape.size(), 1);
|
||||
ShapeVector ret_max_shape = x_shape;
|
||||
for (size_t i = 0; i < ret_in_shape.size(); i++) {
|
||||
if (ret_in_shape[i] > 0) {
|
||||
ret_min_shape[i] = ret_in_shape[i];
|
||||
ret_max_shape[i] = ret_in_shape[i];
|
||||
} else {
|
||||
ret_min_shape[i] = min_shape[i];
|
||||
ret_max_shape[i] = max_shape[i];
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(ret_in_shape, ret_min_shape, ret_max_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,25 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
bool CheckShape(const std::vector<int64_t> &updates_shape, const std::vector<int64_t> &check_shape) {
|
||||
if (std::find(updates_shape.begin(), updates_shape.end(), -2) != updates_shape.end() ||
|
||||
std::find(check_shape.begin(), check_shape.end(), -2) != check_shape.end()) {
|
||||
return true;
|
||||
}
|
||||
if (updates_shape.size() != check_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < updates_shape.size(); ++i) {
|
||||
if (updates_shape[i] == -1 || check_shape[i] == -1) {
|
||||
continue;
|
||||
}
|
||||
if (updates_shape[i] != check_shape[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
abstract::ShapePtr TensorScatterArithmeticInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
|
@ -64,7 +83,7 @@ abstract::ShapePtr TensorScatterArithmeticInferShape(const PrimitivePtr &primiti
|
|||
<< " and the dimension of 'input_x': " << input_x_shape.size();
|
||||
}
|
||||
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
|
||||
if (updates_shape != indices_shape) {
|
||||
if (CheckShape(updates_shape, indices_shape) == false) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
|
||||
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
|
|
|
@ -43,7 +43,9 @@ abstract::AbstractBasePtr TensorShapeInfer(const abstract::AnalysisEnginePtr &,
|
|||
auto min_value = MakeValue(input->shape()->min_shape());
|
||||
auto max_value = MakeValue(input->shape()->max_shape());
|
||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(elem, std::make_shared<abstract::Shape>(tensor_shp));
|
||||
abs_tensor->set_value_range(min_value, max_value);
|
||||
if (!input->shape()->min_shape().empty() && !input->shape()->max_shape().empty()) {
|
||||
abs_tensor->set_value_range(min_value, max_value);
|
||||
}
|
||||
return abs_tensor;
|
||||
}
|
||||
auto shp_buf_size = sizeof(int64_t) * shape.size();
|
||||
|
|
|
@ -33,6 +33,11 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
|
|||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -34,6 +34,11 @@ scatter_nd_update_ds_op_info = TBERegOp("ScatterNdUpdate") \
|
|||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -2356,12 +2356,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
output_max_shape = list(num_segments['max_value'])
|
||||
output_min_shape = list(num_segments['min_value'])
|
||||
else:
|
||||
if isinstance(num_segments_type, type(mstype.tensor)):
|
||||
raise ValueError(f"For '{self.name}', the dtype of 'num_segments' only support int type "
|
||||
f"when it is not a dynamic value, but got type of 'num_segments': "
|
||||
f"{num_segments_type}.")
|
||||
output_max_shape = [num_segments_v]
|
||||
output_min_shape = [num_segments_v]
|
||||
if num_segments_v is None:
|
||||
output_max_shape = []
|
||||
output_min_shape = []
|
||||
if 'max_shape' in x and 'min_shape' in x:
|
||||
max_output_incoming = x['max_shape']
|
||||
min_output_incoming = x['min_shape']
|
||||
|
@ -3671,14 +3670,20 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
'min_shape': tuple(ret_min_shape)}
|
||||
|
||||
if None in (begin_v['value'], end_v['value'], strides_v['value']) or (-1 in x_shape):
|
||||
ret_shape = self._compute_dynamic_slicing_shape(x_shape, begin_len)
|
||||
|
||||
ret_shape, ret_min_shape, ret_max_shape = \
|
||||
self._compute_dynamic_slicing_shape(x_shape, begin_len, max_shape)
|
||||
rets = {'shape': ret_shape,
|
||||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
if -1 in x_shape and (max_shape is None or min_shape is None):
|
||||
return rets
|
||||
return self._compute_max_min_shape(rets, x_shape, max_shape, min_shape, ret_shape)
|
||||
|
||||
if max_shape is not None and min_shape is not None:
|
||||
rets['min_shape'] = ret_min_shape
|
||||
rets['max_shape'] = ret_max_shape
|
||||
|
||||
if -1 not in x_shape:
|
||||
return self._compute_max_min_shape(rets, x_shape, ret_shape)
|
||||
|
||||
return rets
|
||||
|
||||
ret_shape = self._compute_slicing_shape(x_shape, begin_v['value'], end_v['value'], strides_v['value'])
|
||||
if all(ret_shape):
|
||||
|
@ -3712,7 +3717,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
'dtype': x['dtype'],
|
||||
'value': value}
|
||||
|
||||
def _compute_max_min_shape(self, rets, x_shape, max_shape, min_shape, ret_shape):
|
||||
def _compute_max_min_shape(self, rets, x_shape, ret_shape):
|
||||
"""compute max/min shape"""
|
||||
ret_min_shape = [1] * len(x_shape)
|
||||
ret_max_shape = x_shape
|
||||
|
@ -3720,12 +3725,8 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
if val > 0:
|
||||
ret_min_shape[i] = val
|
||||
ret_max_shape[i] = val
|
||||
elif -1 in x_shape:
|
||||
ret_min_shape[i] = min_shape[i]
|
||||
ret_max_shape[i] = max_shape[i]
|
||||
rets['max_shape'] = ret_max_shape
|
||||
rets['min_shape'] = ret_min_shape
|
||||
|
||||
return rets
|
||||
|
||||
def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
|
||||
|
@ -3812,17 +3813,31 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
j += 1
|
||||
return ret_shape
|
||||
|
||||
def _compute_dynamic_slicing_shape(self, x_shape, slice_len):
|
||||
def _compute_dynamic_slicing_shape(self, x_shape, slice_len, max_shape):
|
||||
"""Computes the shape of the slicing for dynamic shape, mask is currently not supported."""
|
||||
x_rank = len(x_shape)
|
||||
if self.begin_mask != 0 or self.end_mask != 0 or self.ellipsis_mask or self.new_axis_mask != 0 \
|
||||
or self.shrink_axis_mask != 0:
|
||||
raise ValueError("Mask is currently not supported if 'begin', 'end' or 'strides' is not a constant.")
|
||||
new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
|
||||
shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
|
||||
if self.ellipsis_mask:
|
||||
raise ValueError("Ellipsis Mask is currently not supported.")
|
||||
ret_shape = []
|
||||
ret_min_shape = []
|
||||
ret_max_shape = []
|
||||
i, j = 0, 0
|
||||
while i < x_rank or j < slice_len:
|
||||
slicing_length = -1 if x_shape[i] != 1 else 1
|
||||
if j >= slice_len:
|
||||
if j < slice_len:
|
||||
if j < len(new_axis_pos) and new_axis_pos[j] == '1':
|
||||
ret_shape.append(1)
|
||||
ret_min_shape.append(1)
|
||||
ret_max_shape.append(1)
|
||||
j += 1
|
||||
continue
|
||||
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
|
||||
j += 1
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
if i >= len(x_shape):
|
||||
raise ValueError(f"For 'StridedSlice', the index must be less than or equal to "
|
||||
f"the dimension of 'input_x', but got the dimension of 'input_x': {len(x_shape)} "
|
||||
|
@ -3831,9 +3846,12 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
if end > 0:
|
||||
slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
|
||||
ret_shape.append(slicing_length)
|
||||
if max_shape is not None:
|
||||
ret_min_shape.append(1)
|
||||
ret_max_shape.append(max_shape[i])
|
||||
i += 1
|
||||
j += 1
|
||||
return ret_shape
|
||||
return ret_shape, ret_min_shape, ret_max_shape
|
||||
|
||||
|
||||
class Diag(PrimitiveWithInfer):
|
||||
|
@ -6673,6 +6691,9 @@ class _TensorScatterOp(PrimitiveWithInfer):
|
|||
return input_x_dtype
|
||||
|
||||
def _check_shape(self, expect, real):
|
||||
"""check shape"""
|
||||
if -2 in expect or -2 in real:
|
||||
return True
|
||||
if len(expect) != len(real):
|
||||
return False
|
||||
for a, b in zip(expect, real):
|
||||
|
|
Loading…
Reference in New Issue