!45037 transfer Reshape's infer_shape from Python to core/ops
Merge pull request !45037 from jxlang910/master
commit c27d9389a3
@@ -178,8 +178,6 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
@@ -623,125 +623,6 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
 }
 
-static ShapeVector GetShape(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list,
-                            const std::string &op_name) {
-  ShapeVector shape;
-  if (args_spec_list.size() == kSizeTwo) {
-    auto input_value = args_spec_list[1]->BuildValue();
-    if (input_value->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", input_value, op_name);
-    } else {
-      shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", input_value, op_name);
-    }
-  } else {
-    ValuePtr sh = primitive->GetAttr("shape");
-    MS_EXCEPTION_IF_NULL(sh);
-    if (sh->isa<ValueTuple>()) {
-      auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
-      MS_EXCEPTION_IF_NULL(reshape_value_tuple);
-      auto reshape_tuple = reshape_value_tuple->value();
-      (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
-                           [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-    } else if (sh->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
-    } else {
-      MS_EXCEPTION(ValueError) << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
-                               << "constant Tensor.";
-    }
-  }
-  return shape;
-}
-
-AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list) {
-  const std::string op_name = primitive->name();
-  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(x->shape());
-  // padding_axis_value is for the case where the number of -1 in shape is greater than 1, used only to pad the shape
-  std::vector<int> padding_axis_value;
-  ValuePtr padding_axis = primitive->GetAttr("reshape_padding_axis");
-  if (padding_axis != nullptr) {
-    padding_axis_value = GetValue<std::vector<int>>(padding_axis);
-  }
-
-  ShapeVector shape = GetShape(primitive, args_spec_list, op_name);
-  ShapeVector x_shape = x->shape()->shape();
-  ShapeVector x_max_shape = x->shape()->max_shape();
-  ShapeVector x_min_shape = x->shape()->min_shape();
-  if (x_max_shape.empty()) {
-    x_max_shape = x_shape;
-  }
-  if (x_min_shape.empty()) {
-    x_min_shape = x_shape;
-  }
-
-  auto max_shape = shape;
-  auto min_shape = shape;
-  int64_t x_num = 1;
-  int64_t x_min_num = 1;
-  int64_t x_max_num = 1;
-  for (int64_t value : x_shape) {
-    x_num = LongMulWithOverflowCheck(value, x_num);
-  }
-  for (int64_t value : x_min_shape) {
-    x_min_num = LongMulWithOverflowCheck(value, x_min_num);
-  }
-  for (int64_t value : x_max_shape) {
-    x_max_num = LongMulWithOverflowCheck(value, x_max_num);
-  }
-
-  auto it_first = find(shape.begin(), shape.end(), -1);
-  if (it_first != shape.end()) {
-    if (!padding_axis_value.empty()) {
-      // the case where the number of -1 in shape is greater than 1, used only to pad the shape
-      for (size_t index = 0; index < padding_axis_value.size(); ++index) {
-        shape[IntToSize(padding_axis_value[index])] = x_shape[index];
-        min_shape[IntToSize(padding_axis_value[index])] = x_min_shape[index];
-        max_shape[IntToSize(padding_axis_value[index])] = x_max_shape[index];
-      }
-    } else {
-      auto it_second = find(it_first + 1, shape.end(), -1);
-      if (it_second != shape.end()) {
-        MS_LOG(EXCEPTION) << "At most one component of input shape can be -1, but got " << shape;
-      }
-      auto index = LongToSize(std::distance(shape.begin(), it_first));
-      int64_t infer_value = x_num;
-      int64_t infer_min_value = x_min_num;
-      int64_t infer_max_value = x_max_num;
-      for (size_t i = 0; i < shape.size(); ++i) {
-        int64_t value = shape[i];
-        if (value != -1 && value != 0) {
-          infer_value = infer_value / value;
-          infer_min_value = infer_min_value / value;
-          infer_max_value = infer_max_value / value;
-        }
-      }
-      shape[index] = infer_value;
-      min_shape[index] = infer_min_value;
-      max_shape[index] = infer_max_value;
-    }
-  }
-
-  int64_t shape_num = 1;
-  for (int64_t value : shape) {
-    shape_num = LongMulWithOverflowCheck(value, shape_num);
-  }
-  if (shape_num != x_num) {
-    MS_LOG(EXCEPTION) << "The product of x_shape must be equal to the product of out_shape, but got x_shape: "
-                      << x_shape << ", and out_shape: " << shape;
-  }
-
-  if (IsDynamic(min_shape) || IsDynamic(max_shape)) {
-    min_shape.clear();
-    max_shape.clear();
-  }
-
-  AbstractTensorPtr ret =
-    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
-  return ret;
-}
-
 AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   // Inputs: one tensor.
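For orientation, the -1 handling that this removed routine implemented (and that the new core/ops version preserves) reduces to the following plain-Python sketch; the function name and structure are illustrative, not MindSpore API:

from functools import reduce
from operator import mul

def infer_reshape(x_shape, shape):
    """Resolve a single -1 entry in `shape` against the element count of `x_shape`."""
    x_num = reduce(mul, x_shape, 1)                    # total element count of the input
    neg_indices = [i for i, s in enumerate(shape) if s == -1]
    if len(neg_indices) > 1:
        raise ValueError(f"At most one component of input shape can be -1, but got {shape}")
    out = list(shape)
    if neg_indices:
        # divide out every known (non -1, non 0) dim; the -1 dim absorbs the remainder
        known = reduce(mul, (s for s in shape if s not in (-1, 0)), 1)
        out[neg_indices[0]] = x_num // known
    if reduce(mul, out, 1) != x_num:
        raise ValueError(f"Cannot reshape {x_shape} into {shape}")
    return out

assert infer_reshape([2, 3, 4], [4, -1]) == [4, 6]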
@@ -400,7 +400,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
     {prim::kPrimTransposeNOD, R{InferImplTranspose, nullptr, true}},
     {prim::kPrimStridedSlice, R{ops::StridedSliceInfer, nullptr, true}},
     {prim::kPrimSlice, R{ops::SliceInfer, nullptr, true}},
-    {prim::kPrimReshape, R{InferImplReshape, nullptr, true}},
     {prim::kPrimSliceGrad, R{ops::SliceGradInfer, nullptr, true}},
     {prim::kPrimConcat, R{InferImplConcat, nullptr, true}},
     {prim::kPrimConcatOffset, R{InferImplConcatOffset, nullptr, true}},
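The map above is the backend's lookup table from primitive to infer implementation, so deleting the kPrimReshape entry re-routes Reshape to the infer registered in core/ops below. A plain-Python sketch of the pattern, with illustrative (hypothetical) names:

def infer_concat(*args):
    # hypothetical stand-in for InferImplConcat
    raise NotImplementedError

# primitive name -> (infer_shape_fn, infer_value_fn, in_white_list)
BACKEND_EVAL_IMPL = {
    "Concat": (infer_concat, None, True),
}

infer_fn, _, _ = BACKEND_EVAL_IMPL["Concat"]   # lookup replaces a hard-coded call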
@@ -14,17 +14,151 @@
  * limitations under the License.
  */
 #include "ops/reshape.h"
 
 #include <string>
 #include <memory>
+#include <functional>
+#include <set>
+#include <algorithm>
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
+ShapeVector update_shape(const std::vector<int> &padding_axis_value, const ShapeVector &x_shape, ShapeVector shape) {
+  // padding_axis_value is for the case where the number of -1 in shape is greater than 1, used only to pad the shape
+  if (std::any_of(x_shape.begin(), x_shape.end(), [](const int64_t &shape_i) { return shape_i < 0; })) {
+    return shape;
+  }
+  int64_t x_num = 1;
+  for (int64_t value : x_shape) {
+    x_num = LongMulWithOverflowCheck(value, x_num);
+  }
+
+  auto it_first = find(shape.begin(), shape.end(), -1);
+  if (it_first != shape.end()) {
+    if (!padding_axis_value.empty()) {
+      // the case where the number of -1 in shape is greater than 1, used only to pad the shape
+      for (size_t index = 0; index < padding_axis_value.size(); ++index) {
+        shape[IntToSize(padding_axis_value[index])] = x_shape[index];
+      }
+    } else {
+      auto it_second = find(it_first + 1, shape.end(), -1);
+      if (it_second != shape.end()) {
+        MS_EXCEPTION(ValueError) << "At most one component of input shape can be -1, but got " << shape;
+      }
+      auto index = LongToSize(std::distance(shape.begin(), it_first));
+      int64_t infer_value = x_num;
+      for (size_t i = 0; i < shape.size(); ++i) {
+        int64_t value = shape[i];
+        if (value != -1 && value != 0) {
+          infer_value = infer_value / value;
+        }
+      }
+      shape[index] = infer_value;
+    }
+  }
+
+  int64_t shape_num = 1;
+  for (int64_t value : shape) {
+    shape_num = LongMulWithOverflowCheck(value, shape_num);
+  }
+  if (shape_num != x_num) {
+    MS_EXCEPTION(ValueError) << "The product of x_shape must be equal to the product of out_shape, but got x_shape: "
+                             << x_shape << ", and out_shape: " << shape;
+  }
+  return shape;
+}
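The padding-axis branch in update_shape is easy to miss: instead of inferring one -1, it overwrites each listed output axis with the matching input dimension. A plain-Python sketch with illustrative names:

def pad_shape(padding_axis, x_shape, shape):
    """Mirror of the padding branch: output axis padding_axis[i] copies x_shape[i]."""
    out = list(shape)
    for i, axis in enumerate(padding_axis):
        out[axis] = x_shape[i]
    return out

assert pad_shape([0, 2], [5, 7], [-1, 3, -1]) == [5, 3, 7]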
+
 MIND_API_OPERATOR_IMPL(Reshape, BaseOperator);
 REGISTER_PRIMITIVE_C(kNameReshape, Reshape);
+class ReshapeInfer : public abstract::OpInferBase {
+ public:
+  BaseShapePtr InferShape(const PrimitivePtr &primitive,
+                          const std::vector<AbstractBasePtr> &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    constexpr size_t max_size = 2;
+    (void)CheckAndConvertUtils::CheckValue<size_t>("input size", input_args.size(), kLessEqual, max_size, prim_name);
+
+    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+
+    std::vector<int64_t> output_shape;
+
+    if (input_args.size() == max_size) {
+      auto input_y = input_args[1];
+      MS_EXCEPTION_IF_NULL(input_y);
+      if (input_y->isa<abstract::AbstractTensor>()) {
+        auto y_value = input_y->BuildValue();
+        MS_EXCEPTION_IF_NULL(y_value);
+        if (y_value != kAnyValue) {
+          output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", y_value, prim_name);
+        } else {
+          abstract::ShapePtr y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
+          auto shape_value = y_shape->shape();
+          if (shape_value.size() != 1) {
+            MS_EXCEPTION(TypeError) << "For '" << prim_name
+                                    << "', the shape size must be 1, but got: " << shape_value.size() << ".";
+          }
+          if (y_shape->IsDynamic()) {
+            output_shape.push_back(abstract::Shape::kShapeRankAny);
+          } else {
+            output_shape = GetShapeValue(primitive, input_y);
+          }
+          return std::make_shared<abstract::Shape>(output_shape);
+        }
+      } else if (input_y->isa<abstract::AbstractTuple>()) {
+        auto y_value = input_y->BuildValue();
+        MS_EXCEPTION_IF_NULL(y_value);
+        output_shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", y_value, primitive->name());
+      } else {
+        MS_EXCEPTION(TypeError) << "input_y must be AbstractTensor or AbstractTuple, but got: " << input_y;
+      }
+    } else {
+      // When the shape is passed as an attribute, it should be constant.
+      ValuePtr sh = primitive->GetAttr("shape");
+      MS_EXCEPTION_IF_NULL(sh);
+      if (sh->isa<ValueTuple>()) {
+        auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
+        MS_EXCEPTION_IF_NULL(reshape_value_tuple);
+        auto reshape_tuple = reshape_value_tuple->value();
+        (void)std::transform(reshape_tuple.begin(), reshape_tuple.end(), std::back_inserter(output_shape),
+                             [=](const ValuePtr &e) -> int64_t {
+                               if (!e->isa<Int64Imm>()) {
+                                 MS_EXCEPTION(TypeError)
+                                   << "For primitive[" << prim_name << "], the 'shape'"
+                                   << " must be a tuple with all Int elements, but got " << sh->ToString();
+                               }
+                               return GetValue<int64_t>(e);
+                             });
+      } else if (sh->isa<tensor::Tensor>()) {
+        output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
+      } else {
+        MS_EXCEPTION(ValueError)
+          << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
+          << "constant Tensor.";
+      }
+    }
+
+    std::vector<int> padding_axis_value;
+    ValuePtr padding_axis = primitive->GetAttr("reshape_padding_axis");
+    if (padding_axis != nullptr) {
+      padding_axis_value = GetValue<std::vector<int>>(padding_axis);
+    }
+    ShapeVector updated_shape = update_shape(padding_axis_value, x_shape, output_shape);
+    return std::make_shared<abstract::Shape>(updated_shape);
+  }
+
+  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
+    for (const auto &item : input_args) {
+      MS_EXCEPTION_IF_NULL(item);
+    }
+    auto x_dtype = input_args[0]->BuildType();
+    std::set<TypePtr> template_types = {kTensorType};
+    (void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
+    return x_dtype;
+  }
+};
+REGISTER_PRIMITIVE_OP_INFER_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer, false);
 }  // namespace ops
 }  // namespace mindspore
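Note the sentinel convention ReshapeInfer relies on for dynamic shapes: -1 marks an unknown dimension and -2 (abstract::Shape::kShapeRankAny) an unknown rank, the same values the removed Python _get_shape_and_range produced. A plain-Python sketch of how the placeholder output shape is chosen when the shape tensor's value is not yet known:

def placeholder_output_shape(rank_known, shape_len=None):
    """-2 when even the rank is unknown; otherwise one -1 per dimension."""
    if not rank_known:
        return [-2]              # kShapeRankAny: rank itself is dynamic
    return [-1] * shape_len      # rank known, every dim still to be determined

assert placeholder_output_shape(False) == [-2]
assert placeholder_output_shape(True, 3) == [-1, -1, -1]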
@@ -37,9 +37,6 @@ class MIND_API Reshape : public BaseOperator {
   /// \brief Init.
   void Init() const {}
 };
-
-abstract::AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const std::vector<abstract::AbstractBasePtr> &input_args);
 }  // namespace ops
 }  // namespace mindspore
@@ -752,7 +752,7 @@ class Col2Im(Primitive):
         self.add_prim_attr('stride', self.stride)
 
 
-class Reshape(PrimitiveWithInfer):
+class Reshape(PrimitiveWithCheck):
     """
     Rearranges the input Tensor based on the given shape.
 
@@ -776,85 +776,46 @@ class Reshape(PrimitiveWithInfer):
         """Initialize Reshape"""
         self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
 
-    @staticmethod
-    def _get_shape_and_range(x, shape):
-        """ get min and max shape when output shape is dynamic"""
-        x_shp = x['shape']
-        if is_shape_unknown(shape['shape']):
-            out_shape = [-2]
-            return out_shape
-
-        shape_rank = shape['shape'][0]
-        if not x_shp:
-            # x is a scalar, output shape is fixed
-            out_shape = [1] * shape_rank
-            return out_shape
-
-        out_shape = [-1] * shape_rank
-
-        if "shape_value" in shape:
-            out_shape = shape["shape_value"]
-        return out_shape
-
-    def _update_shape_and_value(self, out, x, shape_v, dim_prod, neg_index):
-        """ update shape, value and min / max value of output when input shape is known"""
-        x_shp = x['shape']
-        if dim_prod <= 0:
-            raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
-                             f"the value of 'input_shape' is {shape_v}. "
-                             f"The product of 'input_shape' should be > 0, but got {dim_prod}.")
-        arr_prod = np.prod(x_shp)
-        if neg_index != -1:
-            shape_v[neg_index] = int(arr_prod // dim_prod)
-            dim_prod *= shape_v[neg_index]
-        if dim_prod != arr_prod:
-            raise ValueError(f"For '{self.name}', the product of the 'input_x' shape "
-                             f"should be equal to product of 'input_shape', but got product of the"
-                             f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
-        out['shape'] = tuple(shape_v)
-
-        if x['value'] is not None:
-            out['value'] = Tensor(x['value'].asnumpy().reshape(shape_v))
-
-        return out
-
-    def __infer__(self, x, shape):
-        shape_v = shape['value']
-        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
-        # for shape is not constant
-        if shape_v is None:
-            out_shape = self._get_shape_and_range(x, shape)
-            return {
-                'shape': out_shape,
-                'dtype': x['dtype'],
-                'value': None,
-            }
-
-        if isinstance(shape_v, Tensor_):
-            validator.check_tensor_dtype_valid("shape", shape['dtype'], [mstype.int32, mstype.int64], self.name)
-            shape_v = shape_v.asnumpy().tolist()
+    def infer_value(self, x, shape):
+        """infer value"""
+        if shape is None or x is None:
+            return None
+        if isinstance(shape, Tensor_):
+            validator.check_tensor_dtype_valid("shape", mstype.tensor_type(shape.dtype),
+                                               [mstype.int32, mstype.int64], self.name)
+            shape = shape.asnumpy().tolist()
         else:
-            validator.check_value_type("shape", shape_v, [tuple], self.name)
-            shape_v = list(shape_v)
+            validator.check_value_type("shape", shape, [tuple], self.name)
+            shape = list(shape)
 
         neg_index = -1
         dim_prod = 1
-        for i, shp_i in enumerate(shape_v):
+        for i, shp_i in enumerate(shape):
             validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
             if shp_i == -1:
                 if neg_index != -1:
                     raise ValueError(f"For '{self.name}', there can be at most one '-1' in 'input_shape', "
-                                     f"but got {shape_v}.")
+                                     f"but got {shape}.")
                 neg_index = i
             else:
                 dim_prod *= shp_i
 
-        out = {'shape': shape_v,
-               'dtype': x['dtype'],
-               'value': None}
-
-        if not is_shape_unknown(x['shape']):
-            out = self._update_shape_and_value(out, x, shape_v, dim_prod, neg_index)
+        out = None
+        if not is_shape_unknown(x.shape):
+            x_shp = x.shape
+            if dim_prod <= 0:
+                raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
+                                 f"the value of 'input_shape' is {shape}. "
+                                 f"The product of 'input_shape' should be > 0, but got {dim_prod}.")
+            arr_prod = np.prod(x_shp)
+            if neg_index != -1:
+                shape[neg_index] = int(arr_prod // dim_prod)
+                dim_prod *= shape[neg_index]
+            if dim_prod != arr_prod:
+                raise ValueError(f"For '{self.name}', the product of the 'input_x' shape "
+                                 f"should be equal to product of 'input_shape', but got product of the"
+                                 f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
+            out = Tensor(x.asnumpy().reshape(shape))
         return out
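A small usage sketch (assuming MindSpore is installed) of what the new PrimitiveWithCheck path means in practice: when both inputs are constant, infer_value folds the reshape at graph-build time; otherwise it returns None and the result is computed by the kernel, with the shape coming from the new core/ops infer:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(12, dtype=np.float32))
out = ops.Reshape()(x, (3, -1))   # the -1 is inferred as 4
print(out.shape)                  # (3, 4)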