!26766 fix error log and move some function to inner

Merge pull request !26766 from lianliguang/master
i-robot 2021-11-26 06:09:42 +00:00 committed by Gitee
commit 04132e0c50
14 changed files with 60 additions and 287 deletions

View File

@@ -227,7 +227,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP
}
for (auto &elem : axis_data) {
int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank) - 1);
int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank));
(void)axis_set.insert(e_value);
}
MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());
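A brief sketch of why the upper bound widens here (the rank value is assumed purely for illustration): the reworked CheckAxis shown further down uses max as the offset when normalizing a negative axis, so the call site must pass the full rank rather than rank - 1 for a negative axis to land on the intended dimension.
// Assuming x_rank == 4 and the new CheckAxis, which adds max to a negative axis:
//   old upper bound (x_rank - 1 == 3):  axis -1 would normalize to -1 + 3 == 2  (off by one)
//   new upper bound (x_rank == 4):      axis -1 normalizes to -1 + 4 == 3       (last dimension)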

View File

@@ -1019,10 +1019,10 @@ bool SetMindIRGraphAction(const ResourcePtr &res) {
if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before."
<< "Please check the args is same with export.\n"
<< "The export input argument size : " << func_args.size() << "\n"
<< "The load input argument size : " << broaded_args.size() << "\n"
<< "Export input args info:" << abstract::ArgsToString(func_args) << "\n"
<< "The input args info:" << abstract::ArgsToString(broaded_args);
<< "The export input argument size: " << func_args.size() << "\n"
<< "The load input argument size: " << broaded_args.size() << "\n"
<< "Export input args info: " << abstract::ArgsToString(func_args) << "\n"
<< "The input args info: " << abstract::ArgsToString(broaded_args);
}
// suppose that there is no KeywordArgument for the top graph

View File

@@ -326,11 +326,10 @@ void AnalysisResultCacheMgr::Todo() {
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
std::ostringstream buffer;
buffer << "(";
for (const auto &item : args_spec_list) {
buffer << item->ToString() << " # ";
buffer << item->BuildType()->ToString() << "," << item->BuildShape()->ToString() << " #"
<< "\n";
}
buffer << " )";
return buffer.str();
}
} // namespace abstract
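As a hedged illustration of the new ArgsToString output (tensor dtypes and shapes invented for the example), each argument is now summarized by its type and shape rather than its full ToString dump:
// For two float32 tensors of shape (2, 3), the joined string looks roughly like:
//   (Tensor[Float32],(2, 3) #
//   Tensor[Float32],(2, 3) #
//    )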

View File

@@ -445,6 +445,36 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *
}
}
}
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
// As x and expected_type are both statically known MindSpore types, we only need to judge
// whether x is expected_type or a subclass of it.
return IsIdentidityOrSubclass(x, expected_type);
}
// Join all types in args_type_list;
TypePtr TypeJoin(const TypePtrList &args_type_list) {
if (args_type_list.empty()) {
MS_LOG(EXCEPTION) << "args_type_list is empty";
}
TypePtr type_tmp = args_type_list[0];
for (std::size_t i = 1; i < args_type_list.size(); i++) {
type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
}
return type_tmp;
}
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
MS_EXCEPTION_IF_NULL(predicate);
for (const auto &arg_type : args_type_list) {
MS_EXCEPTION_IF_NULL(arg_type);
if (!CheckType(predicate, arg_type)) {
MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
}
}
return TypeJoin(args_type_list);
}
} // end anonymous namespace
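A minimal usage sketch of the helpers that are now file-local (the dtypes here are assumptions for illustration, not taken from the patch):
// CheckTypeList verifies every element is the expected type or a subclass of it, then joins the list;
// it throws if any element fails the check.
TypePtrList arg_types = {kFloat32, kFloat32};
TypePtr joined = CheckTypeList(kFloat32, arg_types);  // joined is Float32 in this case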
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {

View File

@@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@@ -45,7 +45,7 @@ TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &e
if (ok) {
return type;
} else {
MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString();
MS_EXCEPTION(TypeError) << error_message_prefix << " should be " << accepts << ", but got " << type->ToString();
}
}
@@ -79,7 +79,8 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty
TypePtr sample_type = sample_elem->BuildType();
MS_EXCEPTION_IF_NULL(sample_type);
std::ostringstream loginfoBuffer;
loginfoBuffer << "same type, got";
loginfoBuffer << "[" << sample_tensor->BuildType()->ToString();
bool error_flag = false;
// Check if other elements have the same type with the first element.
for (size_t index = 1; index < tensor_list.size(); ++index) {
MS_EXCEPTION_IF_NULL(tensor_list[index]);
@@ -87,12 +88,14 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty
MS_EXCEPTION_IF_NULL(elem);
auto a_type = elem->BuildType();
MS_EXCEPTION_IF_NULL(a_type);
loginfoBuffer << " " << a_type->ToString();
loginfoBuffer << "," << tensor_list[index]->BuildType()->ToString();
if (sample_type->type_id() != a_type->type_id()) {
MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << a_type->ToString()
<< ", index " << index;
error_flag = true;
}
}
if (error_flag) {
MS_EXCEPTION(ValueError) << error_message_prefix << " must be same, but got " << loginfoBuffer.str() << "]";
}
MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str();
return CheckTensorDType(sample_tensor, accepts, error_message_prefix);
}
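As a hedged illustration of the reworked check (operator and dtypes invented for the example, using the BatchNorm prefix introduced later in this patch), all element types are now collected before raising, instead of failing on the first mismatch:
// Hypothetical input: gamma and beta are float32 tensors, mean is float16.
// The resulting message reads roughly:
//   Input arguments[gamma, beta, mean, variance] of BatchNorm must be same, but got
//   [Tensor[Float32],Tensor[Float32],Tensor[Float16],Tensor[Float32]]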
@@ -167,15 +170,19 @@ int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t minimum,
}
int64_t axis_value = GetValue<int64_t>(axis);
if (axis_value > max || axis_value < minimum) {
MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max
<< "], but get " << axis_value;
MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s axis value should be in the range [" << minimum << ", " << max
<< "], but got " << axis_value;
}
if (axis_value < 0) {
axis_value = axis_value + SizeToLong(max);
}
return axis_value;
}
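A minimal sketch of the consolidated behaviour (operator name and values assumed for illustration): CheckAxis now validates the range and normalizes a negative axis in one place, which is why the separate GetPositiveAxis calls are deleted in the files below.
// Assuming a rank-4 input, so minimum = -4 and max = 4:
int64_t a = CheckAxis("ReduceSum", MakeValue(static_cast<int64_t>(-1)), -4, 4);  // returns -1 + 4 == 3
int64_t b = CheckAxis("ReduceSum", MakeValue(static_cast<int64_t>(2)), -4, 4);   // returns 2 unchanged
// A value outside [-4, 4] raises the "axis value should be in the range" exception above.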
void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
size_t size_expect) {
if (args_spec_list.size() != size_expect) {
MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size();
MS_LOG(EXCEPTION) << op << " input arguments size should be " << size_expect << ", but got "
<< args_spec_list.size();
}
for (size_t i = 0; i < size_expect; i++) {
@@ -200,65 +207,6 @@ void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
}
}
int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
MS_EXCEPTION_IF_NULL(attr);
auto int64_value = attr->cast<Int64ImmPtr>();
MS_EXCEPTION_IF_NULL(int64_value);
int64_t attr_val = int64_value->value();
if (attr_val <= 0) {
MS_LOG(EXCEPTION) << op << " invalid " << attr_name << " value: " << attr_val << ", should be greater then 0";
}
return attr_val;
}
std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
const size_t num_element) {
std::vector<int64_t> result;
MS_EXCEPTION_IF_NULL(attr);
if (attr->isa<ValueTuple>()) {
auto tuple_attr = attr->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_attr);
std::vector<ValuePtr> attr_vec = tuple_attr->value();
if (start_idx > attr_vec.size() || start_idx + num_element > attr_vec.size()) {
MS_EXCEPTION(IndexError) << op << " attr index is out of range, attr size is " << attr_vec.size()
<< "but start idx got" << start_idx << " num element " << num_element;
}
auto it_start = attr_vec.begin() + start_idx;
(void)std::transform(it_start, it_start + num_element, std::back_inserter(result),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
} else {
auto int64_imm = attr->cast<Int64ImmPtr>();
MS_EXCEPTION_IF_NULL(int64_imm);
int64_t attr_val = int64_imm->value();
(void)result.insert(result.begin(), num_element, attr_val);
}
return result;
}
std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
const std::set<std::string> &val_set) {
MS_EXCEPTION_IF_NULL(attr);
auto string_attr = attr->cast<StringImmPtr>();
MS_EXCEPTION_IF_NULL(string_attr);
std::string attr_val = string_attr->value();
if (val_set.find(attr_val) == val_set.end()) {
std::ostringstream buffer;
bool f_begin = true;
buffer << "{";
for (auto &x : val_set) {
if (!f_begin) {
buffer << ", ";
} else {
f_begin = false;
}
buffer << x;
}
buffer << "}";
MS_LOG(EXCEPTION) << op << "Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str();
}
return attr_val;
}
void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
size_t size_expect) {
if (args_spec_list.size() < size_expect) {
@@ -268,6 +216,5 @@ void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::Abs
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
}
}
} // namespace abstract
} // namespace mindspore

View File

@@ -53,8 +53,6 @@ void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape);
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape);
int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name);
std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
const size_t num_element);

View File

@@ -149,8 +149,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
// If axis is negative, add offset(rank_base + 1) to turn it to positive.
axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base + 1));
for (size_t i = 1; i < tuple_len; ++i) {
AbstractTensorPtr tensor = nullptr;
@@ -950,8 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr
int64_t rank = SizeToLong(x_shape.size());
ValuePtr axis = primitive->GetAttr("axis");
int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank)));
int64_t axis_value_pos = CheckAxis(op_name, axis, -(rank + 1), rank);
int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
@@ -1097,8 +1094,6 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
// If axis is negative, add offset(rank_base) to turn it to positive.
axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
int64_t all_shp = shape_base[axis_value];
int64_t min_all_shp = min_shape_base[axis_value];

View File

@@ -139,14 +139,14 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be");
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input argument[x] of BatchNorm");
AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
for (size_t i = 1; i < args_spec_list.size(); ++i) {
auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
tensorPtrList.push_back(param);
}
(void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32},
"param gamma, beta, mean, variance of Batchnorm should be");
"Input arguments[gamma, beta, mean, variance] of BatchNorm");
auto data_format_ptr = primitive->GetAttr("format");
MS_EXCEPTION_IF_NULL(data_format_ptr);
@@ -240,113 +240,6 @@ void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const Ab
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
}
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
constexpr auto kConv2DInputNum = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kConv2DInputNum);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());
ShapeVector x_shape = input_x->shape()->shape();
ShapeVector x_min_shape = input_x->shape()->min_shape();
ShapeVector x_max_shape = input_x->shape()->max_shape();
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape);
CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape);
AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(input_w);
MS_EXCEPTION_IF_NULL(input_w->shape());
ShapeVector w_shape = input_w->shape()->shape();
CheckShape(op_name, w_shape, input_w);
const uint64_t n_axis = 0;
uint64_t c_axis = 1;
uint64_t h_axis = 2;
uint64_t w_axis = 3;
int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
if (data_format == Format::NHWC) {
c_axis = 3;
h_axis = 1;
w_axis = 2;
}
int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
((x_shape[c_axis] / group) != w_shape[c_axis])) {
MS_LOG(EXCEPTION) << "x_shape[C_in] / group must be equal to w_shape[C_in]: " << w_shape[c_axis] << ", but got "
<< (x_shape[c_axis] / group);
}
int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must be equal to " << out_channel;
}
const size_t kernel_size_num_element = 2;
std::vector<int64_t> kernel_size =
CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, kernel_size_num_element);
if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
MS_LOG(EXCEPTION) << "weight height: " << w_shape[h_axis] << " must be equal to " << kernel_size[0];
}
if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
MS_LOG(EXCEPTION) << "weight width: " << w_shape[w_axis] << " must be equal to " << kernel_size[1];
}
std::vector<int64_t> stride =
CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), stride_start_idx, stride_num_element);
std::vector<int64_t> dilation =
CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), dilation_start_idx, dilation_num_element);
std::vector<int64_t> padding =
CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element);
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
std::vector<int64_t> output_hw;
std::vector<int64_t> pad_list;
std::vector<int64_t> output_hw_min;
std::vector<int64_t> pad_list_min;
std::vector<int64_t> output_hw_max;
std::vector<int64_t> pad_list_max;
Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
padding);
if (x_shape[h_axis] == Shape::SHP_ANY) {
output_hw[0] = Shape::SHP_ANY;
}
if (x_shape[w_axis] == Shape::SHP_ANY) {
output_hw[1] = Shape::SHP_ANY;
}
Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
dilation, pad_mode, padding);
Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
dilation, pad_mode, padding);
std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
MakeValue(pad_list[3])};
primitive->set_attr("pad_list", MakeValue(pad_list_val));
ShapeVector output_shape;
ShapeVector output_shape_min;
ShapeVector output_shape_max;
if (data_format == Format::NHWC) {
output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
} else {
output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
}
CheckShapeAnyAndPositive(op_name + " output_shape", output_shape);
CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min);
CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max);
TypePtr x_type = input_x->element()->GetTypeTrack();
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
x_type = kInt32;
}
ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
}
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: at least one tensor(y_backprop)

View File

@@ -186,7 +186,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimPooling, R{InferImplPooling, nullptr, true}},
{prim::kPrimPoolingGrad, R{InferImplPoolingGrad, nullptr, true}},
{prim::kPrimBatchNorm, R{InferImplBatchNorm, nullptr, true}},
{prim::kPrimConv2D, R{InferImplConv2D, nullptr, true}},
{prim::kPrimBpropCut, R{InferImplBpropCut, nullptr, true}},
{prim::kPrimDropout, R{InferImplDropout, nullptr, true}},
{prim::kPrimSparseApplyFtrl, R{InferImplSparseApplyFtrl, nullptr, true}},

View File

@@ -183,83 +183,6 @@ AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
return spec->Clone();
}
namespace {
// Join all types in args_type_list;
TypePtr TypeJoin(const TypePtrList &args_type_list) {
if (args_type_list.empty()) {
MS_LOG(EXCEPTION) << "args_type_list is empty";
}
TypePtr type_tmp = args_type_list[0];
for (std::size_t i = 1; i < args_type_list.size(); i++) {
type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
}
return type_tmp;
}
} // namespace
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
// As x and predicate both are mindspore type statically, here we only to judge whether
// x is predicate or is a subclass of predicate.
return IsIdentidityOrSubclass(x, expected_type);
}
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
MS_EXCEPTION_IF_NULL(predicate);
for (const auto &arg_type : args_type_list) {
MS_EXCEPTION_IF_NULL(arg_type);
if (!CheckType(predicate, arg_type)) {
MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
}
}
return TypeJoin(args_type_list);
}
int64_t GetPositiveAxis(int64_t axis_value, size_t increment) {
if (axis_value < 0) {
axis_value = axis_value + SizeToLong(increment);
}
if (axis_value < 0) {
MS_LOG(EXCEPTION) << "axis_value should not still <0";
}
return axis_value;
}
// Return if two shapes can be broadcast.
// Broadcast shape is placed in broadcast_output_shape.
ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) {
std::reverse(x_shape.begin(), x_shape.end());
std::reverse(y_shape.begin(), y_shape.end());
// Fill a placeholder value 1 which will be replaced later.
size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
y_shape.resize(std_len, 1);
x_shape.resize(std_len, 1);
ShapeVector broadcast_shape;
for (size_t i = 0; i < std_len; i++) {
int64_t x_i = x_shape[i]; // i-th dimension of x
int64_t y_i = y_shape[i]; // i-th dimension of y
int64_t output_i = 0; // i-th dimension of the output
if (x_i == y_i) {
output_i = x_i;
} else if (x_i == 1) {
output_i = y_i;
} else if (y_i == 1) {
output_i = x_i;
} else {
MS_LOG(EXCEPTION)
<< op
<< " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting "
"requirements";
}
broadcast_shape.push_back(output_i);
}
std::reverse(broadcast_shape.begin(), broadcast_shape.end());
return broadcast_shape;
}
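For context on this deletion, a worked example of the broadcasting rule the removed helper implemented (shapes invented for illustration):
// x_shape = (8, 1, 6, 1), y_shape = (7, 1, 5)
// right-align the shorter shape:  (8, 1, 6, 1)
//                                 (   7, 1, 5)
// per dimension keep the equal value, or the non-1 value when one side is 1  ->  (8, 7, 6, 5)
// any other mismatch raises the "do not meet the broadcasting requirements" exception.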
ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
if (dlen < 0) {

View File

@@ -43,20 +43,11 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
// else self.Clone;
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec);
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list);
bool CheckType(const TypePtr &expected_type, const TypePtr &x);
int64_t GetPositiveAxis(int64_t axis_value, size_t increment);
ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy);
MS_CORE_API size_t TypeIdSize(const TypeId data_type);
size_t ShapeSize(const std::vector<size_t> &shape);
// Get broadcasted shape for binary element-wise operation
ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y);
// Check dynamic shape routine
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);

View File

@@ -61,11 +61,10 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank) - 1);
int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank));
ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank) - 1);
begin_params_axis = abstract::GetPositiveAxis(begin_params_axis, input_rank);
int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank));
// the beta and gama shape should be x_shape[begin_params_axis:]
auto valid_types = {kFloat16, kFloat32};
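A small worked example of the new bound (rank chosen for illustration): with the upper bound now equal to the rank, a begin_norm_axis of -1 still resolves to the last dimension inside CheckAxis, without the GetPositiveAxis call removed here.
// Assuming input_rank == 3, so the bounds passed are minimum = -1 and max = 3:
//   begin_norm_axis  = -1  ->  in range, negative, so normalized to -1 + 3 == 2  (last dimension)
//   begin_params_axis =  1  ->  in range, returned unchanged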

View File

@@ -104,7 +104,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
engine_->Run(tupleSliceGraphPtr, args_spec_list);
FAIL() << "Excepted exception :Args type is wrong";
} catch (std::runtime_error const &err) {
ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos);
ASSERT_TRUE(std::string(err.what()).find("TupleSlice input arguments size should be 2, but got 3") !=
std::string::npos);
} catch (...) {
FAIL() << "Excepted exception :Args type is wrong";
}
@@ -250,7 +251,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
AbstractBasePtrList eles;
for (size_t i = 0; i < 6; i++) {