support DynamicShape and StridedSlice primitives returning the range of the shape value.

mxm 2020-12-02 10:30:07 +08:00
parent d6dbfd6593
commit 82268dc703
9 changed files with 143 additions and 33 deletions
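In short: a primitive's Python infer result may now carry 'min_value' and 'max_value' entries alongside the existing 'min_shape'/'max_shape' ones, AbstractTensor stores that range in a new (min_value_, max_value_) field pair behind set_value_range()/get_min_value()/get_max_value(), and InferImplDynamicShape fills the pair from the input's min/max shape. That lets consumers of a dynamically produced shape value, such as StridedSlice, reason about its bounds. The short sketches between the files below are illustrative stand-ins for the individual pieces, not MindSpore code.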


@ -273,10 +273,14 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
  size_t len = arg_tuple->size();
  py::tuple shape_tuple(len);
  py::tuple dtype_tuple(len);
  py::tuple value_tuple(len);
  py::tuple min_value_tuple(len);
  py::tuple max_value_tuple(len);
  py::tuple min_shape_tuple(len);
  py::tuple max_shape_tuple(len);
  auto dic = py::dict();
  bool dyn_shape = false;
  bool is_build_value = true;
  for (size_t i = 0; i < len; i++) {
    auto arg = arg_tuple->elements()[i];
@ -284,6 +288,18 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
    shape_tuple[i] = out[ATTR_SHAPE];
    dtype_tuple[i] = out[ATTR_DTYPE];
    // The element in the tuple is a tensor shape value.
    if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
      value_tuple[i] = out[ATTR_VALUE];
      min_value_tuple[i] = out[ATTR_MIN_VALUE];
      max_value_tuple[i] = out[ATTR_MAX_VALUE];
      is_build_value = false;
    } else {
      value_tuple[i] = BuildValue(arg->BuildValue());
      min_value_tuple[i] = value_tuple[i];
      max_value_tuple[i] = value_tuple[i];
    }
    // The element in the tuple is a tensor whose shape is dynamic.
    if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
      min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
@ -296,7 +312,13 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
  }
  dic[ATTR_SHAPE] = shape_tuple;
  dic[ATTR_DTYPE] = dtype_tuple;
  if (is_build_value) {
    dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
  } else {
    dic[ATTR_VALUE] = value_tuple;
    dic[ATTR_MIN_VALUE] = min_value_tuple;
    dic[ATTR_MAX_VALUE] = max_value_tuple;
  }
  if (dyn_shape) {
    dic[ATTR_MIN_SHAPE] = min_shape_tuple;
@ -359,6 +381,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
        dic[ATTR_MAX_SHAPE] = max_shape;
      }
    }
    auto min_value = arg_tensor->get_min_value();
    auto max_value = arg_tensor->get_max_value();
    if (min_value != nullptr && max_value != nullptr) {
      dic[ATTR_MIN_VALUE] = BuildValue(min_value);
      dic[ATTR_MAX_VALUE] = BuildValue(max_value);
    }
    dic[ATTR_DTYPE] = arg_tensor->BuildType();
    dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
  } else if (abs_base->isa<AbstractRowTensor>()) {
} else if (abs_base->isa<AbstractRowTensor>()) {
@ -446,12 +476,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
  auto out_dtype = output[ATTR_DTYPE];
  if (output[ATTR_VALUE].is_none()) {
    auto out_shape = output[ATTR_SHAPE];
    return PyListDtype2AbstractTensor(out_shape, out_dtype, output);
  }
  // Convert the py::object to a Value, then to an AbstractValue.
  ValuePtr converted_ret = nullptr;
@ -466,6 +491,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
    // Replaced with a constant tensor node during specialization.
    auto res_tensor = res_spec->cast<AbstractTensorPtr>();
    res_tensor->set_value(converted_ret);
    SetValueRange(res_tensor, output);
  }
  if (prim_py->IsCustomPrim()) {
    // Raise an error if output_num does not match the infer result.
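The contract between this C++ code and a primitive's Python infer function is a plain dict, and every range key is optional, hence the repeated contains()-before-index pattern above. A minimal standalone sketch of reading those optional keys (pybind11 embedded mode; the payload values here are made up for illustration):

#include <pybind11/embed.h>
namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};  // start an embedded Python interpreter
  py::dict out;
  out["shape"] = py::make_tuple(-1, 4);
  out["min_value"] = py::make_tuple(1, 4);  // hypothetical range of the shape value
  out["max_value"] = py::make_tuple(8, 4);
  // Same contains()-then-index pattern as SetValueRange in the next file.
  py::object obj_min = out.contains("min_value") ? (py::object)out["min_value"] : (py::object)py::none();
  py::object obj_max = out.contains("max_value") ? (py::object)out["max_value"] : (py::object)py::none();
  if (!obj_min.is_none() && !obj_max.is_none()) {
    py::print("shape value range:", obj_min, "to", obj_max);
  }
  return 0;
}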


@ -315,8 +315,70 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
return ret;
}
void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
  if (output.is_none()) {
    return;
  }
  py::object obj_min =
      output.contains(py::str(ATTR_MIN_VALUE)) ? (py::object)output[ATTR_MIN_VALUE] : (py::object)py::none();
  py::object obj_max =
      output.contains(py::str(ATTR_MAX_VALUE)) ? (py::object)output[ATTR_MAX_VALUE] : (py::object)py::none();
  if (!obj_min.is_none() && !obj_max.is_none()) {
    bool converted = true;
    ValuePtr min_value = nullptr;
    ValuePtr max_value = nullptr;
    converted = parse::ConvertData(obj_min, &min_value);
    if (!converted) {
      MS_LOG(EXCEPTION) << "Convert shape min value data failed";
    }
    converted = parse::ConvertData(obj_max, &max_value);
    if (!converted) {
      MS_LOG(EXCEPTION) << "Convert shape max value data failed";
    }
    auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
    abs_tensor->set_value_range(min_value, max_value);
  }
}
AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py::object &type_obj,
                                          const py::object &output) {
  AbstractBasePtr tensor = nullptr;
  auto ret_vec = shape_obj.cast<ShapeVector>();
  auto ret_dtype = type_obj.cast<TypePtr>();
  ShapeVector min_shape_vec;
  ShapeVector max_shape_vec;
  if (!output.is_none()) {
    py::object min_shape =
        output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
    py::object max_shape =
        output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
    if (!min_shape.is_none()) {
      min_shape_vec = min_shape.cast<ShapeVector>();
    }
    if (!max_shape.is_none()) {
      max_shape_vec = max_shape.cast<ShapeVector>();
    }
  }
  auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  if (ret_dtype->isa<TensorType>()) {
    auto tensor_type = type_obj.cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type);
    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
  } else {
    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
  }
  SetValueRange(tensor, output);
  return tensor;
}
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
                                           const py::object &output) {
  if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
    auto ret_vec = shape_obj.cast<ShapeVector>();
    auto ret_dtype = type_obj.cast<TypePtr>();
@ -326,26 +388,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      return abs_scalar;
    }
    return PyList2DynamicShapeTensor(shape_obj, type_obj, output);
  } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
    py::tuple shape_tuple = shape_obj.cast<py::tuple>();
    py::tuple typeid_tuple = type_obj.cast<py::tuple>();
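A note on the signature change here: PyListDtype2AbstractTensor previously took min_shape and max_shape as two separate optional parameters, and now takes the whole infer-output dict, so adding min_value/max_value (or any future optional key) no longer widens the signature at every call site. A standalone sketch of that design choice, with made-up names standing in for the py::dict machinery:

#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the py::dict carrying optional attributes.
using Attrs = std::map<std::string, std::vector<long>>;

std::optional<std::vector<long>> Lookup(const Attrs &out, const std::string &key) {
  auto it = out.find(key);
  if (it == out.end()) return std::nullopt;  // key stays optional
  return it->second;
}

// Instead of Build(shape, min_shape, max_shape, min_value, max_value, ...),
// each consumer pulls only the keys it understands out of the shared dict.
void Build(const std::vector<long> &shape, const Attrs &out) {
  auto min_shape = Lookup(out, "min_shape");
  auto max_shape = Lookup(out, "max_shape");
  (void)shape; (void)min_shape; (void)max_shape;
}

int main() {
  Build({-1, 4}, {{"min_shape", {1, 4}}, {"max_shape", {8, 4}}});
  return 0;
}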


bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val);
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
                                           const py::object &output = py::none());
void SetValueRange(const AbstractBasePtr &tensor, const py::object &output);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_PY_H_


@ -508,6 +508,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
  ShapePtr shp = shape();
  clone->set_shape(shp->Clone());
  clone->set_value(GetValueTrack());
  clone->set_value_range(get_min_value(), get_max_value());
  return clone;
}


@ -28,6 +28,7 @@
#include "utils/log_adapter.h"
#include "utils/hashing.h"
#include "utils/any.h"
#include "utils/flags.h"
#include "base/base.h"
#include "ir/dtype.h"
#include "ir/value.h"
@ -289,6 +290,13 @@ class AbstractTensor : public AbstractUndetermined {
  ~AbstractTensor() override = default;
  MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)
  void set_value_range(const ValuePtr &min_value, const ValuePtr &max_value) {
    min_value_ = min_value;
    max_value_ = max_value;
  }
  const ValuePtr &get_min_value() const { return min_value_; }
  const ValuePtr &get_max_value() const { return max_value_; }
  TypePtr BuildType() const override;
  BaseShapePtr BuildShape() const override;
  AbstractBasePtr Clone() const override;
@ -312,6 +320,8 @@ class AbstractTensor : public AbstractUndetermined {
 protected:
  bool equal_to(const AbstractTensor &other) const;
  ValuePtr min_value_ = nullptr;
  ValuePtr max_value_ = nullptr;
};
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
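Why the Clone() change in the previous file matters: the range lives outside the value and shape tracks that Clone() already copied, so without the extra line a cloned abstract would silently report an unbounded value. A standalone model of the field pair and the clone-preservation rule (illustrative types, not the MindSpore class):

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

using ValueVec = std::shared_ptr<std::vector<long>>;  // nullptr = unknown bound

struct AbstractTensorModel {
  ValueVec min_value;
  ValueVec max_value;

  void set_value_range(ValueVec lo, ValueVec hi) {
    min_value = std::move(lo);
    max_value = std::move(hi);
  }

  AbstractTensorModel Clone() const {
    AbstractTensorModel c;
    c.set_value_range(min_value, max_value);  // mirrors the Clone() fix above
    return c;
  }
};

int main() {
  AbstractTensorModel a;
  a.set_value_range(std::make_shared<std::vector<long>>(std::vector<long>{1, 4}),
                    std::make_shared<std::vector<long>>(std::vector<long>{8, 4}));
  auto b = a.Clone();
  assert(b.min_value && b.max_value);  // the bound survives cloning
  return 0;
}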


@ -726,7 +726,11 @@ AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const Primitive
  ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
  if (has_dyn_shape) {
    auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(64));
    auto min_value = MakeValue(input->shape()->min_shape());
    auto max_value = MakeValue(input->shape()->max_shape());
    auto tensor = std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
    tensor->set_value_range(min_value, max_value);
    return tensor;
  }
  auto shp_buf_size = sizeof(int64_t) * shape.size();
  auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
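Concretely: if the input abstract tensor has, say, shape (-1, 4) with min_shape (1, 4) and max_shape (8, 4), the dynamic branch above returns a 1-D int64 tensor of length 2 whose value is unknown at compile time but whose value range is pinned to [1, 4]..[8, 4]; a consumer such as StridedSlice can then bound the shapes it derives from that value. A standalone model of this propagation (illustrative types, not MindSpore's):

#include <cstdio>
#include <vector>

// The output *value* of a DynamicShape-like op is the input's runtime shape,
// so its value range is exactly the input's (min_shape, max_shape) pair.
struct ShapeValueOut {
  size_t rank;                  // statically known: length of the shape vector
  std::vector<long> min_value;  // copied from the input's min_shape
  std::vector<long> max_value;  // copied from the input's max_shape
};

ShapeValueOut InferDynamicShape(const std::vector<long> &shape, const std::vector<long> &min_shape,
                                const std::vector<long> &max_shape) {
  return {shape.size(), min_shape, max_shape};
}

int main() {
  auto out = InferDynamicShape({-1, 4}, {1, 4}, {8, 4});
  std::printf("rank=%zu, dim0 in [%ld, %ld], dim1 in [%ld, %ld]\n", out.rank, out.min_value[0], out.max_value[0],
              out.min_value[1], out.max_value[1]);
  return 0;
}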


@ -38,4 +38,6 @@ const char ATTR_DTYPE[] = "dtype";
const char ATTR_SHAPE[] = "shape";
const char ATTR_MIN_SHAPE[] = "min_shape";
const char ATTR_MAX_SHAPE[] = "max_shape";
const char ATTR_MIN_VALUE[] = "min_value";
const char ATTR_MAX_VALUE[] = "max_value";
} // namespace mindspore


@ -33,6 +33,8 @@ extern const char ATTR_DTYPE[];
extern const char ATTR_SHAPE[];
extern const char ATTR_MIN_SHAPE[];
extern const char ATTR_MAX_SHAPE[];
extern const char ATTR_MIN_VALUE[];
extern const char ATTR_MAX_VALUE[];
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_FLAGS_H


@ -98,18 +98,40 @@ def test_gatherv2():
            super(Net, self).__init__()
            self.unq = P.Unique()
            self.gather = P.GatherV2()
            self.yy = Tensor(np.ones([8], dtype=np.int32))

        def construct(self, x, y):
            shp = P.Shape()(self.yy)
            y = P.Reshape()(y, shp)
            u, _ = self.unq(y)
            u_shp = P.DynamicShape()(u)
            z = self.gather(x, u, 0)
            return z, u_shp

    x = Tensor(np.ones([20, 12], dtype=np.float32))
    y = Tensor(np.ones([2, 4], dtype=np.int32))
    net = Net()
    net(x, y)


def test_segmentsum():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.unq = P.Unique()
            self.segment_ids = Tensor([0, 0, 1, 2, 1, 1, 1, 1], mstype.int32)
            self.sum = P.UnsortedSegmentSum()

        def construct(self, x):
            u, _ = self.unq(x)
            shp = P.DynamicShape()(u)
            z = self.sum(x, self.segment_ids, shp[0])
            return z, shp[0]

    x = Tensor(np.ones([8], dtype=np.int32))
    net = Net()
    net(x)


def test_addn():
    class Net(nn.Cell):
        def __init__(self):