!34618 StridedSlice: support masks with dynamic shape & ScatterNdUpdate: support int64 indices & Concat/Gather: support feed mode

Merge pull request !34618 from huoxinyou/0519TensorSLice
i-robot 2022-05-26 09:43:00 +00:00 committed by Gitee
commit d4e6a0877d
11 changed files with 142 additions and 64 deletions

View File

@@ -29,10 +29,16 @@
namespace mindspore {
namespace kernel {
void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *input_size_list, size_t *size_i) {
if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == -2) {
size_t kMaxShapeIdx = 1;
int64_t kDynShapeValue = -2;
if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == kDynShapeValue) {
auto input_max_shape = input_json[kJRange];
for (auto &max_shape : input_max_shape) {
(*size_i) = SizetMulWithOverflowCheck((*size_i), LongToSize(max_shape[1]));
if (max_shape[kMaxShapeIdx] < 0) {
(*size_i) = SizetMulWithOverflowCheck((*size_i), 0);
} else {
(*size_i) = SizetMulWithOverflowCheck((*size_i), LongToSize(max_shape[kMaxShapeIdx]));
}
}
MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape.";
} else {
@@ -42,7 +48,6 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *inp
if (j >= input_max_shape.size()) {
MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
}
size_t kMaxShapeIdx = 1;
if (input_max_shape[j][kMaxShapeIdx] == -1) {
MS_LOG(INFO) << "Change -1 Shape to 1";
(*size_i) = SizetMulWithOverflowCheck((*size_i), 1);
@@ -82,10 +87,16 @@ void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *inp
}
void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *output_size_list, size_t *size_i) {
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) {
size_t kMaxShapeIdx = 1;
int64_t kDynShapeValue = -2;
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == kDynShapeValue) {
auto output_max_shape = output_json[kJRange];
for (auto &max_shape : output_max_shape) {
(*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[1]));
if (max_shape[kMaxShapeIdx] < 0) {
(*size_i) = SizetMulWithOverflowCheck((*size_i), 0);
} else {
(*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[kMaxShapeIdx]));
}
}
MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape.";
} else {
@@ -95,7 +106,6 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
if (j >= output_max_shape.size()) {
MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
}
size_t kMaxShapeIdx = 1;
if (output_max_shape[j][kMaxShapeIdx] == -1) {
MS_LOG(INFO) << "Change -1 Shape to 1";
(*size_i) = SizetMulWithOverflowCheck((*size_i), 1);
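
Note on the hunks above: LongToSize on a negative max bound would wrap around to a huge unsigned value, so an unknown max dimension (< 0) now contributes 0 to the computed buffer size instead. A minimal Python sketch of the -2 (unknown rank) path, assuming shape_range holds [min, max] pairs as in the kernel JSON; the function name is hypothetical:

def real_size_from_max_shape(shape, shape_range):
    # Sketch: when the rank itself is dynamic (shape == [-2]), fall back
    # to the max shape; an unknown max bound (< 0) zeroes the size
    # rather than wrapping as an unsigned multiply.
    size = 1
    if shape == [-2]:
        for _min_dim, max_dim in shape_range:
            size *= max_dim if max_dim >= 0 else 0
    return size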

View File

@@ -70,11 +70,17 @@ abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::ve
if (x_shape_ptr->IsDynamic()) {
auto element0_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
auto element0_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
if (element0_max_shape.empty() || element0_min_shape.empty()) {
return std::make_shared<abstract::Shape>(ret_shape);
}
auto ret_max_shape = element0_max_shape;
auto ret_min_shape = element0_min_shape;
for (size_t i = 1; i < elements.size(); ++i) {
auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMaxShape];
auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMinShape];
if (elementi_max_shape.empty() || elementi_min_shape.empty()) {
return std::make_shared<abstract::Shape>(ret_shape);
}
ret_max_shape[axis] += elementi_max_shape[axis];
ret_min_shape[axis] += elementi_min_shape[axis];
}
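
With the early returns above, Concat shape inference degrades to a plain dynamic shape when any input lacks min/max bounds (feed mode), instead of indexing into empty bound vectors. A minimal usage sketch with static shapes, since the API surface is unchanged for dynamic inputs:

import numpy as np
from mindspore import Tensor, ops

x1 = Tensor(np.ones((2, 3), np.float32))
x2 = Tensor(np.ones((2, 3), np.float32))
out = ops.Concat(axis=0)((x1, x2))  # concatenate along axis 0 -> shape (4, 3)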

View File

@@ -49,18 +49,13 @@ abstract::ShapePtr DynamicBroadcastToInferShape(const PrimitivePtr &primitive,
std::vector<int64_t> real_shape;
std::vector<int64_t> max_shape;
std::vector<int64_t> min_shape;
if (y_shape->IsDynamic()) {
auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value();
auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value();
if (y_shape->IsDynamic() || !min_value || !max_value) {
// max shape unknown
output_shape.push_back(-2);
} else {
auto out_dims = LongToSize(y_shape->shape()[0]);
auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value();
auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value();
if (!min_value || !max_value) {
MS_EXCEPTION(ValueError)
<< "For 'BroadcastTo', inputs['shape'] min or max value can not be empty. But got min: " << min_value
<< "max: " << max_value << ".";
}
min_shape = GetValue<std::vector<int64_t>>(min_value);
max_shape = GetValue<std::vector<int64_t>>(max_value);
if (min_shape.size() != out_dims || max_shape.size() != out_dims) {
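
A shape tensor without min/max value bounds no longer triggers the ValueError removed above; the output simply degrades to unknown rank (-2). A rough Python sketch of the rewritten branch, with hypothetical names:

def dynamic_broadcast_out_shape(shape_is_dynamic, min_value, max_value, out_dims):
    # Sketch: without usable value bounds, report unknown rank (-2)
    # instead of raising.
    if shape_is_dynamic or min_value is None or max_value is None:
        return [-2]
    # bounds exist: each output dim stays dynamic (-1) within [min, max]
    return [-1] * out_dims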

View File

@@ -34,8 +34,12 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
abstract::AbstractTensorPtr params =
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
bool ind_has_m_shape = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
bool param_has_m_shape = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
bool ind_dyn =
std::any_of(indices->shape()->shape().begin(), indices->shape()->shape().end(), [](int64_t s) { return s < 0; });
bool param_dyn =
std::any_of(params->shape()->shape().begin(), params->shape()->shape().end(), [](int64_t s) { return s < 0; });
int64_t axis_val = 0;
// 3rd input is a Tensor when Gather is a dynamic shape operator
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
@@ -59,10 +63,10 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
auto params_rank = static_cast<int64_t>(params_shp.size());
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_rank}, op_name);
// either inputs or both can be dynamic and computation requires min/max shapes for both
ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
ShapeVector param_shp_min = (param_has_m_shape) ? params->shape()->min_shape() : params->shape()->shape();
ShapeVector param_shp_max = (param_has_m_shape) ? params->shape()->max_shape() : params->shape()->shape();
ShapeVector indices_shp_min = (ind_has_m_shape) ? indices->shape()->min_shape() : indices->shape()->shape();
ShapeVector indices_shp_max = (ind_has_m_shape) ? indices->shape()->max_shape() : indices->shape()->shape();
// check axis_val within interval: [0, params_rank)
if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) {
MS_LOG(EXCEPTION) << "For 'Gather', axis value must be within range [" << -params_rank << ", " << params_rank
@@ -79,7 +83,7 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve
return out_vec;
};
ShapeVector out_shape = calc_shape(indices_shp, params_shp);
if (ind_dyn || param_dyn) {
if ((ind_dyn || param_dyn) && ind_dyn == ind_has_m_shape && param_dyn == param_has_m_shape) {
ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
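
The renamed flags separate "shape is actually dynamic" (some dim < 0) from "min/max bounds are present", so Gather can now infer in feed mode, where a dynamic input carries no bounds; bounded output shapes are emitted only when the bound information is consistent with which inputs are dynamic. Plain usage for reference:

import numpy as np
from mindspore import Tensor, ops

params = Tensor(np.array([[1, 2], [3, 4], [5, 6]], np.float32))
indices = Tensor(np.array([0, 2], np.int32))
out = ops.Gather()(params, indices, 0)  # rows 0 and 2 -> shape (2, 2)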

View File

@@ -83,7 +83,7 @@ TypePtr ScatterNdArithmeticInferType(const PrimitivePtr &primitive, const std::v
auto input_x_dtype = input_args[kInputIndex0]->BuildType();
auto indices_dtype = input_args[kInputIndex1]->BuildType();
auto updates_dtype = input_args[kInputIndex2]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_dtype, {kInt32}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_dtype, {kInt32, kInt64}, prim_name);
std::map<std::string, TypePtr> type_dict = {{"input_x type", input_x_dtype}, {"updates type", updates_dtype}};
// Only ScatterNdUpdate supports boolean type
if (prim_name == prim::kPrimScatterNdUpdate->name()) {
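
With kInt64 added to the valid set, indices may be int64 as well as int32. A minimal usage sketch:

import numpy as np
from mindspore import Parameter, Tensor, ops

x = Parameter(Tensor(np.zeros((4, 3), np.float32)), name="x")
indices = Tensor(np.array([[0], [2]], np.int64))  # int64 indices now accepted
updates = Tensor(np.ones((2, 3), np.float32))
out = ops.ScatterNdUpdate()(x, indices, updates)  # rows 0 and 2 become ones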

View File

@@ -255,9 +255,8 @@ std::vector<int64_t> ComputeInferShape(const PrimitivePtr &primitive, const std:
return infer_shape;
}
std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
const std::vector<int64_t> &end_v, const std::vector<int64_t> &strides_v,
const std::vector<int64_t> &x_shape, const size_t slice_len) {
ShapeMap DynamicComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
const size_t slice_len, const std::vector<int64_t> &max_shape) {
// currently not support mask
std::vector<int64_t> begin_pos;
std::vector<int64_t> end_pos;
@@ -271,7 +270,10 @@ std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, con
int64_t start;
int64_t finish;
int64_t strides;
ShapeMap shape_map;
std::vector<int64_t> infer_shape;
std::vector<int64_t> infer_min_shape;
std::vector<int64_t> infer_max_shape;
size_t x_rank = x_shape.size();
while (i < x_rank || j < slice_len) {
int64_t slicing_length = -1;
@@ -279,6 +281,20 @@ std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, con
if (x_dim_size == 1) {
slicing_length = 1;
}
if (j < slice_len) {
if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
infer_shape.push_back(1);
infer_min_shape.push_back(1);
infer_max_shape.push_back(1);
j += 1;
continue;
}
if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
j += 1;
i += 1;
continue;
}
}
if (j >= slice_len && x_dim_size > 0) {
start = 0;
finish = x_shape[i];
@@ -288,10 +304,17 @@ std::vector<int64_t> DynamicComputeInferShape(const PrimitivePtr &primitive, con
}
}
infer_shape.push_back(slicing_length);
if (max_shape.size() != 0) {
infer_min_shape.push_back(1);
infer_max_shape.push_back(max_shape[i]);
}
i += 1;
j += 1;
}
return infer_shape;
shape_map[kShape] = infer_shape;
shape_map[kMinShape] = infer_min_shape;
shape_map[kMaxShape] = infer_max_shape;
return shape_map;
}
bool CheckAndGetDynamicSlice(const AbstractBasePtr &input_arg, const std::string &arg_name, ShapeVector *slice_value,
@@ -368,29 +391,17 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
}
if (!slice_dynamic) {
ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, strides_v, x_shape);
bool has_zero_shape = std::any_of(ret_in_shape.begin(), ret_in_shape.end(), [](int64_t i) { return i == 0; });
if (has_zero_shape) {
MS_LOG(EXCEPTION) << "'StridedSlice' doesn't support zero shape, but got out shape: " << ret_in_shape << ".";
}
return std::make_shared<abstract::Shape>(ret_in_shape);
}
ret_in_shape = DynamicComputeInferShape(primitive, begin_v, end_v, strides_v, x_shape, begin_len);
auto ret_shape_map = DynamicComputeInferShape(primitive, x_shape, begin_len, max_shape);
ret_in_shape = ret_shape_map[kShape];
auto ret_min_shape = ret_shape_map[kMinShape];
auto ret_max_shape = ret_shape_map[kMaxShape];
if (x_is_dyn && (max_shape.empty() || min_shape.empty())) {
return std::make_shared<abstract::Shape>(ret_in_shape);
}
ShapeVector ret_min_shape(x_shape.size(), 1);
ShapeVector ret_max_shape = x_shape;
for (size_t i = 0; i < ret_in_shape.size(); i++) {
if (ret_in_shape[i] > 0) {
ret_min_shape[i] = ret_in_shape[i];
ret_max_shape[i] = ret_in_shape[i];
} else {
ret_min_shape[i] = min_shape[i];
ret_max_shape[i] = max_shape[i];
}
}
return std::make_shared<abstract::Shape>(ret_in_shape, ret_min_shape, ret_max_shape);
}
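
Two behavioral points in this file: the static path now rejects zero-sized outputs with an exception, and the dynamic path takes its output shape together with min/max bounds from DynamicComputeInferShape, which also handles the new new_axis/shrink_axis mask logic. A minimal sketch of the static-path restriction:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
out = ops.StridedSlice()(x, (0, 0), (1, 3), (1, 1))  # shape (1, 3)
# begin == end on an axis would produce a zero-sized output, which the
# check above now rejects:
# ops.StridedSlice()(x, (0, 0), (0, 3), (1, 1))  # raises per this change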

View File

@@ -35,6 +35,25 @@
namespace mindspore {
namespace ops {
namespace {
bool CheckShape(const std::vector<int64_t> &updates_shape, const std::vector<int64_t> &check_shape) {
if (std::find(updates_shape.begin(), updates_shape.end(), -2) != updates_shape.end() ||
std::find(check_shape.begin(), check_shape.end(), -2) != check_shape.end()) {
return true;
}
if (updates_shape.size() != check_shape.size()) {
return false;
}
for (size_t i = 0; i < updates_shape.size(); ++i) {
if (updates_shape[i] == -1 || check_shape[i] == -1) {
continue;
}
if (updates_shape[i] != check_shape[i]) {
return false;
}
}
return true;
}
abstract::ShapePtr TensorScatterArithmeticInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
@@ -64,7 +83,7 @@ abstract::ShapePtr TensorScatterArithmeticInferShape(const PrimitivePtr &primiti
<< " and the dimension of 'input_x': " << input_x_shape.size();
}
indices_shape.insert(indices_shape.end(), input_x_shape.begin() + last_dim, input_x_shape.end());
if (updates_shape != indices_shape) {
if (CheckShape(updates_shape, indices_shape) == false) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
<< "updates_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
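
CheckShape relaxes the former exact equality so validation passes under dynamic shapes: -2 (unknown rank) anywhere accepts, and a -1 dim matches anything. The same rule as a Python sketch:

def shapes_compatible(updates_shape, check_shape):
    # -2 means rank/dims unknown: nothing to validate yet.
    if -2 in updates_shape or -2 in check_shape:
        return True
    if len(updates_shape) != len(check_shape):
        return False
    # -1 is an unknown dim and matches anything; known dims must agree.
    return all(u == c or -1 in (u, c)
               for u, c in zip(updates_shape, check_shape))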

View File

@@ -43,7 +43,9 @@ abstract::AbstractBasePtr TensorShapeInfer(const abstract::AnalysisEnginePtr &,
auto min_value = MakeValue(input->shape()->min_shape());
auto max_value = MakeValue(input->shape()->max_shape());
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(elem, std::make_shared<abstract::Shape>(tensor_shp));
abs_tensor->set_value_range(min_value, max_value);
if (!input->shape()->min_shape().empty() && !input->shape()->max_shape().empty()) {
abs_tensor->set_value_range(min_value, max_value);
}
return abs_tensor;
}
auto shp_buf_size = sizeof(int64_t) * shape.size();
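
set_value_range is now applied only when both bounds exist, so TensorShape on a bound-less dynamic input no longer records an empty range. Plain usage for reference, assuming the user-facing TensorShape primitive:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.ones((2, 3), np.float32))
shp = ops.TensorShape()(x)  # Tensor([2, 3]) with int64 elements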

View File

@@ -33,6 +33,11 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()

View File

@@ -34,6 +34,11 @@ scatter_nd_update_ds_op_info = TBERegOp("ScatterNdUpdate") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()

View File

@@ -2356,12 +2356,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
output_max_shape = list(num_segments['max_value'])
output_min_shape = list(num_segments['min_value'])
else:
if isinstance(num_segments_type, type(mstype.tensor)):
raise ValueError(f"For '{self.name}', the dtype of 'num_segments' only support int type "
f"when it is not a dynamic value, but got type of 'num_segments': "
f"{num_segments_type}.")
output_max_shape = [num_segments_v]
output_min_shape = [num_segments_v]
if num_segments_v is None:
output_max_shape = []
output_min_shape = []
if 'max_shape' in x and 'min_shape' in x:
max_output_incoming = x['max_shape']
min_output_incoming = x['min_shape']
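
When num_segments arrives as a runtime Tensor with no min/max value (feed mode), the output bounds are now left empty rather than raising the removed ValueError. Plain usage with a constant num_segments for reference:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.array([1.0, 2.0, 3.0, 4.0], np.float32))
segment_ids = Tensor(np.array([0, 0, 1, 1], np.int32))
out = ops.UnsortedSegmentSum()(x, segment_ids, 2)  # [3.0, 7.0]
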
@@ -3671,14 +3670,20 @@ class StridedSlice(PrimitiveWithInfer):
'min_shape': tuple(ret_min_shape)}
if None in (begin_v['value'], end_v['value'], strides_v['value']) or (-1 in x_shape):
ret_shape = self._compute_dynamic_slicing_shape(x_shape, begin_len)
ret_shape, ret_min_shape, ret_max_shape = \
self._compute_dynamic_slicing_shape(x_shape, begin_len, max_shape)
rets = {'shape': ret_shape,
'dtype': x['dtype'],
'value': None}
if -1 in x_shape and (max_shape is None or min_shape is None):
return rets
return self._compute_max_min_shape(rets, x_shape, max_shape, min_shape, ret_shape)
if max_shape is not None and min_shape is not None:
rets['min_shape'] = ret_min_shape
rets['max_shape'] = ret_max_shape
if -1 not in x_shape:
return self._compute_max_min_shape(rets, x_shape, ret_shape)
return rets
ret_shape = self._compute_slicing_shape(x_shape, begin_v['value'], end_v['value'], strides_v['value'])
if all(ret_shape):
@@ -3712,7 +3717,7 @@ class StridedSlice(PrimitiveWithInfer):
'dtype': x['dtype'],
'value': value}
def _compute_max_min_shape(self, rets, x_shape, max_shape, min_shape, ret_shape):
def _compute_max_min_shape(self, rets, x_shape, ret_shape):
"""compute max/min shape"""
ret_min_shape = [1] * len(x_shape)
ret_max_shape = x_shape
@@ -3720,12 +3725,8 @@ class StridedSlice(PrimitiveWithInfer):
if val > 0:
ret_min_shape[i] = val
ret_max_shape[i] = val
elif -1 in x_shape:
ret_min_shape[i] = min_shape[i]
ret_max_shape[i] = max_shape[i]
rets['max_shape'] = ret_max_shape
rets['min_shape'] = ret_min_shape
return rets
def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
@@ -3812,17 +3813,31 @@ class StridedSlice(PrimitiveWithInfer):
j += 1
return ret_shape
def _compute_dynamic_slicing_shape(self, x_shape, slice_len):
def _compute_dynamic_slicing_shape(self, x_shape, slice_len, max_shape):
"""Computes the shape of the slicing for dynamic shape, mask is currently not supported."""
x_rank = len(x_shape)
if self.begin_mask != 0 or self.end_mask != 0 or self.ellipsis_mask or self.new_axis_mask != 0 \
or self.shrink_axis_mask != 0:
raise ValueError("Mask is currently not supported if 'begin', 'end' or 'strides' is not a constant.")
new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
if self.ellipsis_mask:
raise ValueError("Ellipsis Mask is currently not supported.")
ret_shape = []
ret_min_shape = []
ret_max_shape = []
i, j = 0, 0
while i < x_rank or j < slice_len:
slicing_length = -1 if x_shape[i] != 1 else 1
if j >= slice_len:
if j < slice_len:
if j < len(new_axis_pos) and new_axis_pos[j] == '1':
ret_shape.append(1)
ret_min_shape.append(1)
ret_max_shape.append(1)
j += 1
continue
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
j += 1
i += 1
continue
else:
if i >= len(x_shape):
raise ValueError(f"For 'StridedSlice', the index must be less than or equal to "
f"the dimension of 'input_x', but got the dimension of 'input_x': {len(x_shape)} "
@@ -3831,9 +3846,12 @@ class StridedSlice(PrimitiveWithInfer):
if end > 0:
slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
ret_shape.append(slicing_length)
if max_shape is not None:
ret_min_shape.append(1)
ret_max_shape.append(max_shape[i])
i += 1
j += 1
return ret_shape
return ret_shape, ret_min_shape, ret_max_shape
class Diag(PrimitiveWithInfer):
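
On the Python side, _compute_dynamic_slicing_shape now decodes new_axis_mask and shrink_axis_mask bit by bit (only ellipsis_mask still raises), so masked slices work even when 'begin', 'end' or 'strides' are not constants. A minimal sketch of the mask encoding on a static input; the expected shape follows the usual StridedSlice semantics:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
# Bit 1 of new_axis_mask inserts a length-1 axis at position 1,
# so the expected output shape is (2, 1, 3).
op = ops.StridedSlice(new_axis_mask=0b10)
out = op(x, (0, 0), (2, 3), (1, 1))
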
@@ -6673,6 +6691,9 @@ class _TensorScatterOp(PrimitiveWithInfer):
return input_x_dtype
def _check_shape(self, expect, real):
"""check shape"""
if -2 in expect or -2 in real:
return True
if len(expect) != len(real):
return False
for a, b in zip(expect, real):