forked from mindspore-Ecosystem/mindspore
!26922 modify error info of conv2d bprop
Merge pull request !26922 from wangnan39/bug_fix_conv_x_size_error
This commit is contained in:
commit
d74ad572d9
|
@ -848,9 +848,9 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
|
||||||
if (args_spec_list.size() == 2) {
|
if (args_spec_list.size() == 2) {
|
||||||
auto input_value = args_spec_list[1]->BuildValue();
|
auto input_value = args_spec_list[1]->BuildValue();
|
||||||
if (input_value->isa<tensor::Tensor>()) {
|
if (input_value->isa<tensor::Tensor>()) {
|
||||||
shape = CheckAndConvertUtils::CheckTensorIntValue("reshape args value", input_value, op_name);
|
shape = CheckAndConvertUtils::CheckTensorIntValue("shape", input_value, op_name);
|
||||||
} else {
|
} else {
|
||||||
shape = CheckAndConvertUtils::CheckAttrTupleInt("reshape args value", input_value, op_name);
|
shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", input_value, op_name);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ValuePtr sh = primitive->GetAttr("shape");
|
ValuePtr sh = primitive->GetAttr("shape");
|
||||||
|
|
|
@ -99,11 +99,13 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
|
||||||
}
|
}
|
||||||
} else if (filter_size->isa<abstract::AbstractTuple>()) {
|
} else if (filter_size->isa<abstract::AbstractTuple>()) {
|
||||||
// check tensor, tuple or int to raise error.
|
// check tensor, tuple or int to raise error.
|
||||||
out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("filter_size", filter_size_v, prim_name);
|
out_shape = CheckAndConvertUtils::CheckTupleInt("input[filter_size]", filter_size_v, prim_name);
|
||||||
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
||||||
} else {
|
} else {
|
||||||
MS_EXCEPTION(TypeError) << "Conv2DBackpropFilter filter_size must be a tuple or tensor, but "
|
auto size_type = filter_size->BuildType();
|
||||||
<< filter_size->ToString();
|
MS_EXCEPTION_IF_NULL(size_type);
|
||||||
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[filter size] must be a tuple or Tensor, "
|
||||||
|
<< "but got " << size_type->ToString();
|
||||||
}
|
}
|
||||||
return ret_shape;
|
return ret_shape;
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,9 +35,10 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
// check
|
// check
|
||||||
auto kernel_size =
|
auto kernel_size =
|
||||||
CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
|
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[kernel_size]", primitive->GetAttr(kKernelSize), prim_name);
|
||||||
auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
|
auto stride = CheckAndConvertUtils::CheckIntOrTupleInt("attribute[stride]", primitive->GetAttr(kStride), prim_name);
|
||||||
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
|
auto dilation =
|
||||||
|
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
|
||||||
// default pad mode is valid
|
// default pad mode is valid
|
||||||
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
||||||
int64_t pad_mode;
|
int64_t pad_mode;
|
||||||
|
@ -128,10 +129,13 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
||||||
}
|
}
|
||||||
} else if (input_size->isa<abstract::AbstractTuple>()) {
|
} else if (input_size->isa<abstract::AbstractTuple>()) {
|
||||||
// check tensor, tuple or int to raise error.
|
// check tensor, tuple or int to raise error.
|
||||||
out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("input x size", input_size_v, prim_name);
|
out_shape = CheckAndConvertUtils::CheckTupleInt("input[x size]", input_size_v, prim_name);
|
||||||
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
||||||
} else {
|
} else {
|
||||||
MS_EXCEPTION(TypeError) << "Conv2DBackpropInput x_size must be a tuple or tensor, but " << input_size->ToString();
|
auto size_type = input_size->BuildType();
|
||||||
|
MS_EXCEPTION_IF_NULL(size_type);
|
||||||
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[x size] must be a tuple or Tensor, "
|
||||||
|
<< "but got " << size_type->ToString();
|
||||||
}
|
}
|
||||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];
|
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
||||||
void CheckSliceType(const AbstractBasePtr &input_arg, const std::string &arg_name, const std::string &prim_name) {
|
void CheckSliceType(const AbstractBasePtr &input_arg, const std::string &arg_name, const std::string &prim_name) {
|
||||||
if (input_arg->isa<abstract::AbstractTuple>()) {
|
if (input_arg->isa<abstract::AbstractTuple>()) {
|
||||||
auto temp_value = input_arg->BuildValue();
|
auto temp_value = input_arg->BuildValue();
|
||||||
(void)CheckAndConvertUtils::CheckAttrTupleInt(arg_name, temp_value, prim_name);
|
(void)CheckAndConvertUtils::CheckTupleInt(arg_name, temp_value, prim_name);
|
||||||
return;
|
return;
|
||||||
} else if (input_arg->isa<abstract::AbstractTensor>()) {
|
} else if (input_arg->isa<abstract::AbstractTensor>()) {
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, prim_name);
|
(void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, prim_name);
|
||||||
|
@ -52,7 +52,7 @@ abstract::ShapePtr StridedSliceGradInferShape(const PrimitivePtr &primitive,
|
||||||
abstract::ShapePtr ret_shape;
|
abstract::ShapePtr ret_shape;
|
||||||
|
|
||||||
if (shapex->isa<abstract::AbstractTuple>()) {
|
if (shapex->isa<abstract::AbstractTuple>()) {
|
||||||
out_shape = CheckAndConvertUtils::CheckAttrTupleInt("shapex", shape_value, prim_name);
|
out_shape = CheckAndConvertUtils::CheckTupleInt("input[shapex]", shape_value, prim_name);
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ abstract::ShapePtr OnesInferShape(const PrimitivePtr &primitive, const std::vect
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
// check
|
// check
|
||||||
auto shape_value = input_args[0]->BuildValue();
|
auto shape_value = input_args[0]->BuildValue();
|
||||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
|
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
|
||||||
if (input_value->isa<tensor::Tensor>()) {
|
if (input_value->isa<tensor::Tensor>()) {
|
||||||
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name);
|
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name);
|
||||||
} else {
|
} else {
|
||||||
tmp_input = CheckAndConvertUtils::CheckAttrTupleInt("slice args value", input_value, prim_name);
|
tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, prim_name);
|
||||||
}
|
}
|
||||||
(void)input_values.emplace_back(tmp_input);
|
(void)input_values.emplace_back(tmp_input);
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,12 +69,9 @@ void GetAndCheckAttrMask(const PrimitivePtr &primitive, std::vector<int64_t> *be
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
int64_t GetSlicingLengthForPositiveStrides(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
||||||
int64_t slicing_length = 0;
|
int64_t slicing_length = 0;
|
||||||
if (strides > 0) {
|
if ((start_pos < x_dim) && end_pos >= -x_dim) {
|
||||||
if ((start_pos >= x_dim) || end_pos < -x_dim) {
|
|
||||||
slicing_length = 0;
|
|
||||||
} else {
|
|
||||||
if (-x_dim <= start_pos && start_pos < 0) {
|
if (-x_dim <= start_pos && start_pos < 0) {
|
||||||
start_pos += x_dim;
|
start_pos += x_dim;
|
||||||
}
|
}
|
||||||
|
@ -93,10 +90,12 @@ int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides
|
||||||
slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
|
slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
return slicing_length;
|
||||||
if (start_pos < -x_dim || end_pos >= x_dim) {
|
}
|
||||||
slicing_length = 0;
|
|
||||||
} else {
|
int64_t GetSlicingLengthForNegativeStrides(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
||||||
|
int64_t slicing_length = 0;
|
||||||
|
if (start_pos >= -x_dim && end_pos < x_dim) {
|
||||||
if (start_pos > 0 && start_pos < x_dim) {
|
if (start_pos > 0 && start_pos < x_dim) {
|
||||||
start_pos += -x_dim;
|
start_pos += -x_dim;
|
||||||
}
|
}
|
||||||
|
@ -115,6 +114,18 @@ int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides
|
||||||
slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
|
slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return slicing_length;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t ComputeSlicingLength(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
|
||||||
|
int64_t slicing_length = 0;
|
||||||
|
if (strides == 0) {
|
||||||
|
MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
|
||||||
|
}
|
||||||
|
if (strides > 0) {
|
||||||
|
slicing_length = GetSlicingLengthForPositiveStrides(start_pos, end_pos, strides, x_dim);
|
||||||
|
} else {
|
||||||
|
slicing_length = GetSlicingLengthForNegativeStrides(start_pos, end_pos, strides, x_dim);
|
||||||
}
|
}
|
||||||
return slicing_length;
|
return slicing_length;
|
||||||
}
|
}
|
||||||
|
@ -288,7 +299,7 @@ bool CheckAndGetDynamicSlice(const AbstractBasePtr &input_arg, const std::string
|
||||||
auto input_value = input_arg->BuildValue();
|
auto input_value = input_arg->BuildValue();
|
||||||
MS_EXCEPTION_IF_NULL(input_value);
|
MS_EXCEPTION_IF_NULL(input_value);
|
||||||
if (input_arg->isa<abstract::AbstractTuple>()) {
|
if (input_arg->isa<abstract::AbstractTuple>()) {
|
||||||
*slice_value = CheckAndConvertUtils::CheckAttrTupleInt(arg_name, input_value, "StridedSlice");
|
*slice_value = CheckAndConvertUtils::CheckTupleInt(arg_name, input_value, "StridedSlice");
|
||||||
*slice_len = (*slice_value).size();
|
*slice_len = (*slice_value).size();
|
||||||
} else if (input_arg->isa<abstract::AbstractTensor>()) {
|
} else if (input_arg->isa<abstract::AbstractTensor>()) {
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, "StridedSlice");
|
(void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->BuildType(), {kInt64}, "StridedSlice");
|
||||||
|
|
|
@ -64,9 +64,9 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect
|
||||||
std::vector<int64_t> multiples_v;
|
std::vector<int64_t> multiples_v;
|
||||||
auto multiple_value = input_args[1]->BuildValue();
|
auto multiple_value = input_args[1]->BuildValue();
|
||||||
if (multiple_value->isa<tensor::Tensor>()) {
|
if (multiple_value->isa<tensor::Tensor>()) {
|
||||||
multiples_v = CheckAndConvertUtils::CheckTensorIntValue("tile multiples value", multiple_value, prim_name);
|
multiples_v = CheckAndConvertUtils::CheckTensorIntValue("multiples", multiple_value, prim_name);
|
||||||
} else {
|
} else {
|
||||||
multiples_v = CheckAndConvertUtils::CheckAttrTupleInt("tile multiples value", multiple_value, prim_name);
|
multiples_v = CheckAndConvertUtils::CheckTupleInt("input[multiples]", multiple_value, prim_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto multiple : multiples_v) {
|
for (auto multiple : multiples_v) {
|
||||||
|
|
|
@ -46,9 +46,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
auto perm_value = input_args[1]->BuildValue();
|
auto perm_value = input_args[1]->BuildValue();
|
||||||
MS_EXCEPTION_IF_NULL(perm_value);
|
MS_EXCEPTION_IF_NULL(perm_value);
|
||||||
if (perm_value->isa<tensor::Tensor>()) {
|
if (perm_value->isa<tensor::Tensor>()) {
|
||||||
p_value = CheckAndConvertUtils::CheckTensorIntValue("perm value", perm_value, op_name);
|
p_value = CheckAndConvertUtils::CheckTensorIntValue("perm", perm_value, op_name);
|
||||||
} else {
|
} else {
|
||||||
p_value = CheckAndConvertUtils::CheckAttrTupleInt("perm value", perm_value, op_name);
|
p_value = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (x_shape.size() != p_value.size()) {
|
if (x_shape.size() != p_value.size()) {
|
||||||
|
|
|
@ -31,7 +31,7 @@ abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vec
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
// check
|
// check
|
||||||
auto shape_value = input_args[0]->BuildValue();
|
auto shape_value = input_args[0]->BuildValue();
|
||||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
|
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
}
|
}
|
||||||
|
|
|
@ -546,7 +546,7 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
|
||||||
ShapeVector tensor_value;
|
ShapeVector tensor_value;
|
||||||
if (!value->isa<tensor::Tensor>()) {
|
if (!value->isa<tensor::Tensor>()) {
|
||||||
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
|
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
|
||||||
<< "] must be a tensor,but got " << value->ToString();
|
<< "] must be a tensor, but got " << value->ToString();
|
||||||
}
|
}
|
||||||
auto input_tensor = value->cast<tensor::TensorPtr>();
|
auto input_tensor = value->cast<tensor::TensorPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||||
|
@ -565,7 +565,7 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
|
||||||
tensor_value = {tensor_data, tensor_data + data_size};
|
tensor_value = {tensor_data, tensor_data + data_size};
|
||||||
} else {
|
} else {
|
||||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name
|
||||||
<< "] must be a Tensor[Int64] or Tensor[Int32] type,but got " << value->ToString();
|
<< "] must be a Tensor[Int64] or Tensor[Int32] type, but got " << value->ToString();
|
||||||
}
|
}
|
||||||
return tensor_value;
|
return tensor_value;
|
||||||
}
|
}
|
||||||
|
@ -726,7 +726,7 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||||
const std::string &prim_name) {
|
const std::string &prim_name) {
|
||||||
std::vector<int64_t> result;
|
std::vector<int64_t> result;
|
||||||
bool is_correct = false;
|
bool is_correct = false;
|
||||||
|
@ -749,13 +749,13 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::str
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!is_correct) {
|
if (!is_correct) {
|
||||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
|
||||||
<< "] must be a Int or a tuple with all Int elements, but got " << attr->ToString();
|
<< " must be a Int or a tuple with all Int elements, but got " << attr->ToString();
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||||
const std::string &prim_name) {
|
const std::string &prim_name) {
|
||||||
std::vector<int64_t> result;
|
std::vector<int64_t> result;
|
||||||
MS_EXCEPTION_IF_NULL(attr);
|
MS_EXCEPTION_IF_NULL(attr);
|
||||||
|
@ -764,14 +764,14 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &
|
||||||
(void)std::transform(
|
(void)std::transform(
|
||||||
attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
|
attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
|
||||||
if (!e->isa<Int64Imm>()) {
|
if (!e->isa<Int64Imm>()) {
|
||||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
|
||||||
<< "] must be a tuple with all Int elements, but got " << attr->ToString();
|
<< " must be a tuple with all Int elements, but got " << attr->ToString();
|
||||||
}
|
}
|
||||||
return GetValue<int64_t>(e);
|
return GetValue<int64_t>(e);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
|
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s " << arg_name
|
||||||
<< "] must be a tuple with all Int elements, but got " << attr->ToString() << ".";
|
<< " must be a tuple with all Int elements, but got " << attr->ToString() << ".";
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -306,9 +306,9 @@ class CheckAndConvertUtils {
|
||||||
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
|
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
|
||||||
const std::string &class_name);
|
const std::string &class_name);
|
||||||
static void CheckMode(const std::string &class_name);
|
static void CheckMode(const std::string &class_name);
|
||||||
static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
static std::vector<int64_t> CheckIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||||
const std::string &arg_name);
|
const std::string &arg_name);
|
||||||
static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
static std::vector<int64_t> CheckTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||||
const std::string &arg_name);
|
const std::string &arg_name);
|
||||||
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
|
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
|
||||||
static int64_t GetAndCheckFormat(const ValuePtr &value);
|
static int64_t GetAndCheckFormat(const ValuePtr &value);
|
||||||
|
|
|
@ -34,7 +34,7 @@ def _check_mul():
|
||||||
finally:
|
finally:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
print(f"MindSpore version:", ms.__version__)
|
print(f"MindSpore version: ", ms.__version__)
|
||||||
|
|
||||||
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
|
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
|
||||||
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
|
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
|
||||||
|
|
Loading…
Reference in New Issue