commit 88573c5fc6
!41491 add dynamic_shape for reshape ops
Merge pull request !41491 from jxlang910/master
@@ -186,8 +186,6 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -740,125 +740,6 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
 }
 
-static ShapeVector GetShape(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list,
-                            const std::string &op_name) {
-  ShapeVector shape;
-  if (args_spec_list.size() == kSizeTwo) {
-    auto input_value = args_spec_list[1]->BuildValue();
-    if (input_value->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", input_value, op_name);
-    } else {
-      shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", input_value, op_name);
-    }
-  } else {
-    ValuePtr sh = primitive->GetAttr("shape");
-    MS_EXCEPTION_IF_NULL(sh);
-    if (sh->isa<ValueTuple>()) {
-      auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
-      MS_EXCEPTION_IF_NULL(reshape_value_tuple);
-      auto reshape_tuple = reshape_value_tuple->value();
-      (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
-                           [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-    } else if (sh->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
-    } else {
-      MS_EXCEPTION(ValueError) << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
-                               << "constant Tensor.";
-    }
-  }
-  return shape;
-}
-
-AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list) {
-  const std::string op_name = primitive->name();
-  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(x->shape());
-  // padding_axis_value handles the case where shape contains more than one -1; it is only used to pad the shape
-  std::vector<int> padding_axis_value;
-  ValuePtr padding_axis = primitive->GetAttr("reshape_padding_axis");
-  if (padding_axis != nullptr) {
-    padding_axis_value = GetValue<std::vector<int>>(padding_axis);
-  }
-
-  ShapeVector shape = GetShape(primitive, args_spec_list, op_name);
-  ShapeVector x_shape = x->shape()->shape();
-  ShapeVector x_max_shape = x->shape()->max_shape();
-  ShapeVector x_min_shape = x->shape()->min_shape();
-  if (x_max_shape.empty()) {
-    x_max_shape = x_shape;
-  }
-  if (x_min_shape.empty()) {
-    x_min_shape = x_shape;
-  }
-
-  auto max_shape = shape;
-  auto min_shape = shape;
-  int64_t x_num = 1;
-  int64_t x_min_num = 1;
-  int64_t x_max_num = 1;
-  for (int64_t value : x_shape) {
-    x_num = LongMulWithOverflowCheck(value, x_num);
-  }
-  for (int64_t value : x_min_shape) {
-    x_min_num = LongMulWithOverflowCheck(value, x_min_num);
-  }
-  for (int64_t value : x_max_shape) {
-    x_max_num = LongMulWithOverflowCheck(value, x_max_num);
-  }
-
-  auto it_first = find(shape.begin(), shape.end(), -1);
-  if (it_first != shape.end()) {
-    if (!padding_axis_value.empty()) {
-      // the case where shape contains more than one -1: only pad the shape
-      for (size_t index = 0; index < padding_axis_value.size(); ++index) {
-        shape[IntToSize(padding_axis_value[index])] = x_shape[index];
-        min_shape[IntToSize(padding_axis_value[index])] = x_min_shape[index];
-        max_shape[IntToSize(padding_axis_value[index])] = x_max_shape[index];
-      }
-    } else {
-      auto it_second = find(it_first + 1, shape.end(), -1);
-      if (it_second != shape.end()) {
-        MS_LOG(EXCEPTION) << "At most one component of input shape can be -1, but got " << shape;
-      }
-      auto index = LongToSize(std::distance(shape.begin(), it_first));
-      int64_t infer_value = x_num;
-      int64_t infer_min_value = x_min_num;
-      int64_t infer_max_value = x_max_num;
-      for (size_t i = 0; i < shape.size(); ++i) {
-        int64_t value = shape[i];
-        if (value != -1 && value != 0) {
-          infer_value = infer_value / value;
-          infer_min_value = infer_min_value / value;
-          infer_max_value = infer_max_value / value;
-        }
-      }
-      shape[index] = infer_value;
-      min_shape[index] = infer_min_value;
-      max_shape[index] = infer_max_value;
-    }
-  }
-
-  int64_t shape_num = 1;
-  for (int64_t value : shape) {
-    shape_num = LongMulWithOverflowCheck(value, shape_num);
-  }
-  if (shape_num != x_num) {
-    MS_LOG(EXCEPTION) << "The product of x_shape must be equal to that of out_shape, but got x_shape: " << x_shape
-                      << ", and out_shape: " << shape;
-  }
-
-  if (IsDynamic(min_shape) || IsDynamic(max_shape)) {
-    min_shape.clear();
-    max_shape.clear();
-  }
-
-  AbstractTensorPtr ret =
-    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
-  return ret;
-}
-
 AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   // Inputs: one tensor.
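Note: the -1 handling removed here (and re-added in ops/reshape.cc below) infers the
missing dimension by dividing the element count of x by the product of the known
dimensions, then checks that the products match. A minimal Python sketch of that
arithmetic, using a hypothetical helper name infer_neg_one_dim:

    # Sketch only: mirrors the -1 inference and product check above.
    from functools import reduce

    def infer_neg_one_dim(x_shape, shape):
        if shape.count(-1) > 1:
            raise ValueError(f"At most one component of input shape can be -1, but got {shape}")
        x_num = reduce(lambda a, b: a * b, x_shape, 1)  # total elements of x
        if -1 in shape:
            known = reduce(lambda a, b: a * b, (v for v in shape if v not in (-1, 0)), 1)
            shape[shape.index(-1)] = x_num // known     # fill the -1 slot
        if reduce(lambda a, b: a * b, shape, 1) != x_num:
            raise ValueError(f"Product mismatch: x_shape {x_shape} vs out_shape {shape}")
        return shape

    print(infer_neg_one_dim([3, 4, 5], [2, -1, 5]))  # [2, 6, 5]
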
@@ -391,7 +391,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
     {prim::kPrimStridedSlice, R{ops::StridedSliceInfer, nullptr, true}},
     {prim::kPrimSlice, R{ops::SliceInfer, nullptr, true}},
     {prim::kPrimSliceGrad, R{ops::SliceGradInfer, nullptr, true}},
-    {prim::kPrimReshape, R{InferImplReshape, nullptr, true}},
     {prim::kPrimConcat, R{InferImplConcat, nullptr, true}},
     {prim::kPrimConcatOffset, R{InferImplConcatOffset, nullptr, true}},
     {prim::kPrimTransData, R{InferImplTransData, nullptr, true}},
@@ -14,17 +14,155 @@
  * limitations under the License.
  */
 #include "ops/reshape.h"
 
 #include <string>
 #include <memory>
 
+#include <functional>
+#include <set>
+#include <algorithm>
 #include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
+ShapeVector update_shape(std::vector<int> padding_axis_value, ShapeVector x_shape, ShapeVector shape) {
+  // padding_axis_value handles the case where shape contains more than one -1; it is only used to pad the shape
+  if (std::any_of(x_shape.begin(), x_shape.end(), [](const int &shape_i) { return shape_i < 0; })) {
+    return shape;
+  }
+  int64_t x_num = 1;
+  for (int64_t value : x_shape) {
+    x_num = LongMulWithOverflowCheck(value, x_num);
+  }
+
+  auto it_first = find(shape.begin(), shape.end(), -1);
+  if (it_first != shape.end()) {
+    if (!padding_axis_value.empty()) {
+      // the case where shape contains more than one -1: only pad the shape
+      for (size_t index = 0; index < padding_axis_value.size(); ++index) {
+        shape[IntToSize(padding_axis_value[index])] = x_shape[index];
+      }
+    } else {
+      auto it_second = find(it_first + 1, shape.end(), -1);
+      if (it_second != shape.end()) {
+        MS_EXCEPTION(ValueError) << "At most one component of input shape can be -1, but got " << shape;
+      }
+      auto index = LongToSize(std::distance(shape.begin(), it_first));
+      int64_t infer_value = x_num;
+      for (size_t i = 0; i < shape.size(); ++i) {
+        int64_t value = shape[i];
+        if (value != -1 && value != 0) {
+          infer_value = infer_value / value;
+        }
+      }
+      shape[index] = infer_value;
+    }
+  }
+
+  int64_t shape_num = 1;
+  for (int64_t value : shape) {
+    shape_num = LongMulWithOverflowCheck(value, shape_num);
+  }
+  if (shape_num != x_num) {
+    MS_EXCEPTION(ValueError) << "The product of x_shape must be equal to that of out_shape, but got x_shape: "
+                             << x_shape << ", and out_shape: " << shape;
+  }
+  return shape;
+}
+
 MIND_API_OPERATOR_IMPL(Reshape, BaseOperator);
 REGISTER_PRIMITIVE_C(kNameReshape, Reshape);
+class ReshapeInfer : public abstract::OpInferBase {
+ public:
+  BaseShapePtr InferShape(const PrimitivePtr &primitive,
+                          const std::vector<AbstractBasePtr> &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+
+    std::vector<int64_t> output_shape;
+    size_t normal_inputs_number = 2;
+    if (input_args.size() == normal_inputs_number) {
+      auto input_y = input_args[1];
+      MS_EXCEPTION_IF_NULL(input_y);
+      auto y_value = input_y->BuildValue();
+      MS_EXCEPTION_IF_NULL(y_value);
+      if (input_y->isa<abstract::AbstractTensor>()) {
+        if (y_value->isa<tensor::Tensor>()) {
+          output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", y_value, prim_name);
+        } else {
+          abstract::ShapePtr y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
+          auto shape_value = y_shape->shape();
+          if (shape_value.size() != 1) {
+            MS_EXCEPTION(TypeError) << "For '" << prim_name
+                                    << "', the shape size must be 1, but got: " << shape_value.size() << ".";
+          }
+          if (y_shape->IsDynamic()) {
+            output_shape.push_back(UNKNOWN_RANK);
+            return std::make_shared<abstract::Shape>(output_shape);
+          } else if (x_shape.size() == 0) {
+            MS_LOG(DEBUG) << "x is a scalar.";
+            MS_LOG(DEBUG) << "the size of 'shape' is: " << shape_value[0];
+            for (int i = 0; i < SizeToInt(shape_value[0]); i++) {
+              output_shape.push_back(1);
+            }
+            return std::make_shared<abstract::Shape>(output_shape);
+          } else {
+            auto y_tensor = input_y->cast<abstract::AbstractTensorPtr>();
+            auto tensor_shape_value = y_tensor->get_shape_value();
+            if (tensor_shape_value == nullptr) {
+              output_shape = ShapeVector(shape_value[0], UNKNOWN_DIM);
+              return std::make_shared<abstract::Shape>(output_shape);
+            } else {
+              auto shape_vector = GetValue<ShapeVector>(tensor_shape_value);
+              MS_EXCEPTION_IF_CHECK_FAIL(LongToSize(shape_value[0]) == shape_vector.size(),
+                                         "Illegal shape of shape value");
+              output_shape = shape_vector;
+            }
+          }
+        }
+      } else if (input_y->isa<abstract::AbstractTuple>()) {
+        output_shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", y_value, primitive->name());
+      } else {
+        MS_EXCEPTION(TypeError) << "input_y must be AbstractTensor or AbstractTuple, but got: " << input_y;
+      }
+    } else {
+      ValuePtr sh = primitive->GetAttr("shape");
+      MS_EXCEPTION_IF_NULL(sh);
+      if (sh->isa<ValueTuple>()) {
+        auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
+        MS_EXCEPTION_IF_NULL(reshape_value_tuple);
+        auto reshape_tuple = reshape_value_tuple->value();
+        (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(output_shape),
+                             [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
+      } else if (sh->isa<tensor::Tensor>()) {
+        output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
+      } else {
+        MS_EXCEPTION(ValueError)
+          << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
+          << "constant Tensor.";
+      }
+    }
+
+    std::vector<int> padding_axis_value;
+    ValuePtr padding_axis = primitive->GetAttr("reshape_padding_axis");
+    if (padding_axis != nullptr) {
+      padding_axis_value = GetValue<std::vector<int>>(padding_axis);
+    }
+    ShapeVector updated_shape = update_shape(padding_axis_value, x_shape, output_shape);
+    return std::make_shared<abstract::Shape>(updated_shape);
+  }
+
+  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
+    for (const auto &item : input_args) {
+      MS_EXCEPTION_IF_NULL(item);
+    }
+    auto x_dtype = input_args[0]->BuildType();
+    std::set<TypePtr> template_types = {kTensorType};
+    (void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
+    return x_dtype;
+  }
+};
+REGISTER_PRIMITIVE_OP_INFER_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer, false);
 }  // namespace ops
 }  // namespace mindspore
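The InferShape branches above fall into three outcomes: a fully known output shape
(constant 'shape' input), an output of unknown rank (the 'shape' tensor itself is
dynamic, encoded as -2 / UNKNOWN_RANK), and a known-rank output with unknown
dimensions (-1 / UNKNOWN_DIM per axis). A hedged Python sketch of that decision,
with illustrative names only:

    # Sketch only: the three dynamic-shape outcomes of ReshapeInfer::InferShape.
    def sketch_infer_shape(shape_value, shape_rank, x_is_scalar):
        if shape_value is not None:    # 'shape' is a compile-time constant
            return list(shape_value)
        if shape_rank is None:         # even the rank of 'shape' is unknown
            return [-2]                # UNKNOWN_RANK
        if x_is_scalar:                # scalar input reshapes to all-ones
            return [1] * shape_rank
        return [-1] * shape_rank       # UNKNOWN_DIM for every axis

    print(sketch_infer_shape(None, 3, False))  # [-1, -1, -1]
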
@@ -37,9 +37,6 @@ class MIND_API Reshape : public BaseOperator {
   /// \brief Init.
   void Init() const {}
 };
-
-abstract::AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const std::vector<abstract::AbstractBasePtr> &input_args);
 }  // namespace ops
 }  // namespace mindspore
@@ -109,6 +109,7 @@ class UnravelIndex(Primitive):
     [[0 2]
      [1 2]]
     """
 
     @prim_attr_register
     def __init__(self):
         """Initialize Shape"""
@@ -605,6 +606,7 @@ class Im2Col(Primitive):
     >>> print(y.shape)
     (4, 36, 30, 30)
     """
 
     @prim_attr_register
     def __init__(self, ksizes, strides=1, dilations=1, padding_mode="CALCULATED", pads=0):
         """Initialize Im2Col."""
@@ -749,7 +751,7 @@ class Col2Im(Primitive):
         self.add_prim_attr('stride', self.stride)
 
 
-class Reshape(PrimitiveWithInfer):
+class Reshape(PrimitiveWithCheck):
     """
     Rearranges the input Tensor based on the given shape.
 
@@ -773,85 +775,45 @@ class Reshape(PrimitiveWithInfer):
         """Initialize Reshape"""
         self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
 
-    @staticmethod
-    def _get_shape_and_range(x, shape):
-        """ get min and max shape when output shape is dynamic"""
-        x_shp = x['shape']
-        if is_shape_unknown(shape['shape']):
-            out_shape = [-2]
-            return out_shape
-
-        shape_rank = shape['shape'][0]
-        if not x_shp:
-            # x is a scalar, output shape fixed
-            out_shape = [1] * shape_rank
-            return out_shape
-
-        out_shape = [-1] * shape_rank
-
-        if "shape_value" in shape:
-            out_shape = shape["shape_value"]
-        return out_shape
-
-    def _update_shape_and_value(self, out, x, shape_v, dim_prod, neg_index):
-        """ update shape, value and min / max value of output when input shape is known"""
-        x_shp = x['shape']
-        if dim_prod <= 0:
-            raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
-                             f"the value of 'input_shape' is {shape_v}. "
-                             f"The product of 'input_shape' should be > 0, but got {dim_prod}.")
-        arr_prod = np.prod(x_shp)
-        if neg_index != -1:
-            shape_v[neg_index] = int(arr_prod // dim_prod)
-            dim_prod *= shape_v[neg_index]
-        if dim_prod != arr_prod:
-            raise ValueError(f"For '{self.name}', the product of the 'input_x' shape "
-                             f"should be equal to product of 'input_shape', but got product of the"
-                             f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
-        out['shape'] = tuple(shape_v)
-
-        if x['value'] is not None:
-            out['value'] = Tensor(x['value'].asnumpy().reshape(shape_v))
-
-        return out
-
-    def __infer__(self, x, shape):
-        shape_v = shape['value']
-        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
+    def infer_value(self, x, shape):
+        """infer value"""
-        # for shape is not constant
-        if shape_v is None:
-            out_shape = self._get_shape_and_range(x, shape)
-            return {
-                'shape': out_shape,
-                'dtype': x['dtype'],
-                'value': None,
-            }
-
-        if isinstance(shape_v, Tensor_):
-            validator.check_tensor_dtype_valid("shape", shape['dtype'], [mstype.int32, mstype.int64], self.name)
-            shape_v = shape_v.asnumpy().tolist()
+        if shape is None or x is None:
+            return None
+        if isinstance(shape, Tensor_):
+            validator.check_tensor_dtype_valid("shape", shape.dtype, [mstype.int32, mstype.int64], self.name)
+            shape = shape.asnumpy().tolist()
         else:
-            validator.check_value_type("shape", shape_v, [tuple], self.name)
-            shape_v = list(shape_v)
+            validator.check_value_type("shape", shape, [tuple], self.name)
+            shape = list(shape)
 
         neg_index = -1
         dim_prod = 1
-        for i, shp_i in enumerate(shape_v):
+        for i, shp_i in enumerate(shape):
             validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
             if shp_i == -1:
                 if neg_index != -1:
                     raise ValueError(f"For '{self.name}', there can be at most one '-1' in 'input_shape', "
-                                     f"but got {shape_v}.")
+                                     f"but got {shape}.")
                 neg_index = i
             else:
                 dim_prod *= shp_i
 
-        out = {'shape': shape_v,
-               'dtype': x['dtype'],
-               'value': None}
-
-        if not is_shape_unknown(x['shape']):
-            out = self._update_shape_and_value(out, x, shape_v, dim_prod, neg_index)
+        out = None
+        if not is_shape_unknown(x.shape):
+            x_shp = x.shape
+            if dim_prod <= 0:
+                raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
+                                 f"the value of 'input_shape' is {shape}. "
+                                 f"The product of 'input_shape' should be > 0, but got {dim_prod}.")
+            arr_prod = np.prod(x_shp)
+            if neg_index != -1:
+                shape[neg_index] = int(arr_prod // dim_prod)
+                dim_prod *= shape[neg_index]
+            if dim_prod != arr_prod:
+                raise ValueError(f"For '{self.name}', the product of the 'input_x' shape "
+                                 f"should be equal to product of 'input_shape', but got product of the"
+                                 f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
+            out = Tensor(x.asnumpy().reshape(shape))
         return out
 
 
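With PrimitiveWithCheck, infer_value above constant-folds the reshape only when both
the data and the target shape are known at compile time, and returns None otherwise
so the output is computed at runtime. A minimal sketch of that contract, with
fold_reshape as an illustrative stand-in:

    # Sketch only: mirrors the None-vs-Tensor contract of Reshape.infer_value.
    import numpy as np
    from mindspore import Tensor

    def fold_reshape(x, shape):
        if x is None or shape is None:   # non-constant input: defer to runtime
            return None
        return Tensor(x.asnumpy().reshape(shape))

    out = fold_reshape(Tensor(np.arange(6)), (2, 3))
    print(out.shape if out is not None else "computed at runtime")  # (2, 3)
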
@@ -6018,7 +5980,6 @@ class EmbeddingLookup(PrimitiveWithCheck):
             raise ValueError(f"For '{self.name}', the dimension of 'input_params' must <= 2, "
                              f"but got {len(params_shp)}.")
 
-
 class GatherD(Primitive):
     """
     Gathers elements along an axis specified by dim.
@@ -0,0 +1,77 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+from mindspore import ops as P
+from mindspore import Tensor, nn
+from mindspore import context
+from mindspore.common import dtype as mstype
+
+
+class ReshapeNet(nn.Cell):
+    def __init__(self):
+        super(ReshapeNet, self).__init__()
+        self.reshape = P.Reshape()
+
+    def construct(self, input_x, output_shape):
+        return self.reshape(input_x, output_shape)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_reshape_aicpu_dynamic_graph():
+    """
+    Feature: Test Reshape operation for dynamic shape in graph mode.
+    Description: The input shape is dynamic and the target shape is passed in as a tensor.
+    Expectation: Assert the result is equal to the numpy result.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    net = ReshapeNet()
+    input_shape_dyn = Tensor(shape=(3, None, 5), dtype=mstype.float64)
+    out_shape_dyn = Tensor(np.random.randint(0, 5, size=3))
+    net.set_inputs(input_shape_dyn, out_shape_dyn)
+    input_shape = (3, 4, 5)
+    out_shape = [2, 6, 5]
+    input_x = np.random.random(input_shape)
+    output = net(Tensor(input_x), Tensor(out_shape))
+    expect = input_x.reshape(out_shape)
+    assert np.all(output.asnumpy() == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_reshape_aicpu_dynamic_pynative():
+    """
+    Feature: Test Reshape operation for dynamic shape in pynative mode.
+    Description: The input shape is dynamic and the target shape is passed in as a tensor.
+    Expectation: Assert the result is equal to the numpy result.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net = ReshapeNet()
+    input_shape_dyn = Tensor(shape=(3, None, 5), dtype=mstype.float64)
+    out_shape_dyn = Tensor(np.random.randint(0, 5, size=3))
+    net.set_inputs(input_shape_dyn, out_shape_dyn)
+    input_shape = (3, 4, 5)
+    out_shape = [2, 6, 5]
+    input_x = np.random.random(input_shape)
+    output = net(Tensor(input_x), Tensor(out_shape))
+    expect = input_x.reshape(out_shape)
+    assert np.all(output.asnumpy() == expect)
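As a usage note, the set_inputs pattern exercised by these tests declares which axes
are dynamic; the compiled net should then accept other concrete shapes along those
axes. A hedged sketch (assuming the same ReshapeNet as above):

    # Sketch only: reuse one compiled graph for a different runtime shape.
    import numpy as np
    from mindspore import Tensor, context
    from mindspore.common import dtype as mstype

    context.set_context(mode=context.GRAPH_MODE)
    net = ReshapeNet()
    net.set_inputs(Tensor(shape=(3, None, 5), dtype=mstype.float64),
                   Tensor(np.ones(3, dtype=np.int64)))
    x = np.random.random((3, 8, 5))
    out = net(Tensor(x), Tensor(np.array([6, 4, 5], dtype=np.int64)))
    assert out.asnumpy().shape == (6, 4, 5)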