forked from mindspore-Ecosystem/mindspore
!26766 fix error log and move some function to inner
Merge pull request !26766 from lianliguang/master
This commit is contained in:
commit
04132e0c50
|
@ -227,7 +227,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP
|
|||
}
|
||||
|
||||
for (auto &elem : axis_data) {
|
||||
int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank) - 1);
|
||||
int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank));
|
||||
(void)axis_set.insert(e_value);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());
|
||||
|
|
|
@ -1019,10 +1019,10 @@ bool SetMindIRGraphAction(const ResourcePtr &res) {
|
|||
if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
|
||||
MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before."
|
||||
<< "Please check the args is same with export.\n"
|
||||
<< "The export input argument size : " << func_args.size() << "\n"
|
||||
<< "The load input argument size : " << broaded_args.size() << "\n"
|
||||
<< "Export input args info:" << abstract::ArgsToString(func_args) << "\n"
|
||||
<< "The input args info:" << abstract::ArgsToString(broaded_args);
|
||||
<< "The export input argument size: " << func_args.size() << "\n"
|
||||
<< "The load input argument size: " << broaded_args.size() << "\n"
|
||||
<< "Export input args info: " << abstract::ArgsToString(func_args) << "\n"
|
||||
<< "The input args info: " << abstract::ArgsToString(broaded_args);
|
||||
}
|
||||
|
||||
// suppose that there is not KeywordArgument for the top graph
|
||||
|
|
|
@ -326,11 +326,10 @@ void AnalysisResultCacheMgr::Todo() {
|
|||
|
||||
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "(";
|
||||
for (const auto &item : args_spec_list) {
|
||||
buffer << item->ToString() << " # ";
|
||||
buffer << item->BuildType()->ToString() << "," << item->BuildShape()->ToString() << " #"
|
||||
<< "\n";
|
||||
}
|
||||
buffer << " )";
|
||||
return buffer.str();
|
||||
}
|
||||
} // namespace abstract
|
||||
|
|
|
@ -445,6 +445,36 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
|
||||
// As x and predicate both are mindspore type statically, here we only to judge whether
|
||||
// x is predicate or is a subclass of predicate.
|
||||
return IsIdentidityOrSubclass(x, expected_type);
|
||||
}
|
||||
|
||||
// Join all types in args_type_list;
|
||||
TypePtr TypeJoin(const TypePtrList &args_type_list) {
|
||||
if (args_type_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_type_list is empty";
|
||||
}
|
||||
|
||||
TypePtr type_tmp = args_type_list[0];
|
||||
for (std::size_t i = 1; i < args_type_list.size(); i++) {
|
||||
type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
|
||||
}
|
||||
return type_tmp;
|
||||
}
|
||||
|
||||
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
|
||||
MS_EXCEPTION_IF_NULL(predicate);
|
||||
for (const auto &arg_type : args_type_list) {
|
||||
MS_EXCEPTION_IF_NULL(arg_type);
|
||||
if (!CheckType(predicate, arg_type)) {
|
||||
MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
|
||||
}
|
||||
}
|
||||
return TypeJoin(args_type_list);
|
||||
}
|
||||
} // end anonymous namespace
|
||||
|
||||
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||
|
|
|
@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -45,7 +45,7 @@ TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &e
|
|||
if (ok) {
|
||||
return type;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString();
|
||||
MS_EXCEPTION(TypeError) << error_message_prefix << " should be " << accepts << ",but got " << type->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,8 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty
|
|||
TypePtr sample_type = sample_elem->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(sample_type);
|
||||
std::ostringstream loginfoBuffer;
|
||||
loginfoBuffer << "same type, got";
|
||||
loginfoBuffer << "[" << sample_tensor->BuildType()->ToString();
|
||||
bool error_flag = false;
|
||||
// Check if other elements have the same type with the first element.
|
||||
for (size_t index = 1; index < tensor_list.size(); ++index) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list[index]);
|
||||
|
@ -87,12 +88,14 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty
|
|||
MS_EXCEPTION_IF_NULL(elem);
|
||||
auto a_type = elem->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(a_type);
|
||||
loginfoBuffer << " " << a_type->ToString();
|
||||
loginfoBuffer << "," << tensor_list[index]->BuildType()->ToString();
|
||||
if (sample_type->type_id() != a_type->type_id()) {
|
||||
MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << a_type->ToString()
|
||||
<< ", index " << index;
|
||||
error_flag = true;
|
||||
}
|
||||
}
|
||||
if (error_flag) {
|
||||
MS_EXCEPTION(ValueError) << error_message_prefix << " must be same, but got " << loginfoBuffer.str() << "]";
|
||||
}
|
||||
MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str();
|
||||
return CheckTensorDType(sample_tensor, accepts, error_message_prefix);
|
||||
}
|
||||
|
@ -167,15 +170,19 @@ int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t minimum,
|
|||
}
|
||||
int64_t axis_value = GetValue<int64_t>(axis);
|
||||
if (axis_value > max || axis_value < minimum) {
|
||||
MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max
|
||||
<< "], but get " << axis_value;
|
||||
MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s axis value should be in the range [" << minimum << ", " << max
|
||||
<< "], but got " << axis_value;
|
||||
}
|
||||
if (axis_value < 0) {
|
||||
axis_value = axis_value + SizeToLong(max);
|
||||
}
|
||||
return axis_value;
|
||||
}
|
||||
void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
|
||||
size_t size_expect) {
|
||||
if (args_spec_list.size() != size_expect) {
|
||||
MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size();
|
||||
MS_LOG(EXCEPTION) << op << " input arguments size should be " << size_expect << ", but got "
|
||||
<< args_spec_list.size();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < size_expect; i++) {
|
||||
|
@ -200,65 +207,6 @@ void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
|
|||
}
|
||||
}
|
||||
|
||||
int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
auto int64_value = attr->cast<Int64ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(int64_value);
|
||||
int64_t attr_val = int64_value->value();
|
||||
if (attr_val <= 0) {
|
||||
MS_LOG(EXCEPTION) << op << " invalid " << attr_name << " value: " << attr_val << ", should be greater then 0";
|
||||
}
|
||||
return attr_val;
|
||||
}
|
||||
|
||||
std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
|
||||
const size_t num_element) {
|
||||
std::vector<int64_t> result;
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
if (attr->isa<ValueTuple>()) {
|
||||
auto tuple_attr = attr->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_attr);
|
||||
std::vector<ValuePtr> attr_vec = tuple_attr->value();
|
||||
if (start_idx > attr_vec.size() || start_idx + num_element > attr_vec.size()) {
|
||||
MS_EXCEPTION(IndexError) << op << " attr index is out of range, attr size is " << attr_vec.size()
|
||||
<< "but start idx got" << start_idx << " num element " << num_element;
|
||||
}
|
||||
auto it_start = attr_vec.begin() + start_idx;
|
||||
(void)std::transform(it_start, it_start + num_element, std::back_inserter(result),
|
||||
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
|
||||
} else {
|
||||
auto int64_imm = attr->cast<Int64ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(int64_imm);
|
||||
int64_t attr_val = int64_imm->value();
|
||||
(void)result.insert(result.begin(), num_element, attr_val);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
|
||||
const std::set<std::string> &val_set) {
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
auto string_attr = attr->cast<StringImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(string_attr);
|
||||
std::string attr_val = string_attr->value();
|
||||
if (val_set.find(attr_val) == val_set.end()) {
|
||||
std::ostringstream buffer;
|
||||
bool f_begin = true;
|
||||
buffer << "{";
|
||||
for (auto &x : val_set) {
|
||||
if (!f_begin) {
|
||||
buffer << ", ";
|
||||
} else {
|
||||
f_begin = false;
|
||||
}
|
||||
buffer << x;
|
||||
}
|
||||
buffer << "}";
|
||||
MS_LOG(EXCEPTION) << op << "Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str();
|
||||
}
|
||||
return attr_val;
|
||||
}
|
||||
|
||||
void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
|
||||
size_t size_expect) {
|
||||
if (args_spec_list.size() < size_expect) {
|
||||
|
@ -268,6 +216,5 @@ void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::Abs
|
|||
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,8 +53,6 @@ void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape);
|
|||
|
||||
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape);
|
||||
|
||||
int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name);
|
||||
|
||||
std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
|
||||
const size_t num_element);
|
||||
|
||||
|
|
|
@ -149,8 +149,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
ValuePtr axis = primitive->GetAttr("axis");
|
||||
// Axis value should be in [-(rank_base + 1), rank_base).
|
||||
int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
|
||||
// If axis is negative, add offset(rank_base + 1) to turn it to positive.
|
||||
axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base + 1));
|
||||
|
||||
for (size_t i = 1; i < tuple_len; ++i) {
|
||||
AbstractTensorPtr tensor = nullptr;
|
||||
|
@ -950,8 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
int64_t rank = SizeToLong(x_shape.size());
|
||||
|
||||
ValuePtr axis = primitive->GetAttr("axis");
|
||||
int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
|
||||
uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank)));
|
||||
int64_t axis_value_pos = CheckAxis(op_name, axis, -(rank + 1), rank);
|
||||
int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
|
||||
if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
|
||||
MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
|
||||
|
@ -1097,8 +1094,6 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
ValuePtr axis = primitive->GetAttr("axis");
|
||||
// Axis value should be in [-(rank_base + 1), rank_base).
|
||||
int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
|
||||
// If axis is negative, add offset(rank_base) to turn it to positive.
|
||||
axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
|
||||
|
||||
int64_t all_shp = shape_base[axis_value];
|
||||
int64_t min_all_shp = min_shape_base[axis_value];
|
||||
|
|
|
@ -139,14 +139,14 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
|
||||
|
||||
auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be");
|
||||
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input argument[x] of BatchNorm");
|
||||
AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
|
||||
for (size_t i = 1; i < args_spec_list.size(); ++i) {
|
||||
auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
|
||||
tensorPtrList.push_back(param);
|
||||
}
|
||||
(void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32},
|
||||
"param gamma, beta, mean, variance of Batchnorm should be");
|
||||
"Input arguments[gamma, beta, mean, variance] of BatchNorm");
|
||||
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
|
@ -240,113 +240,6 @@ void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const Ab
|
|||
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr auto kConv2DInputNum = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kConv2DInputNum);
|
||||
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
MS_EXCEPTION_IF_NULL(input_x->shape());
|
||||
ShapeVector x_shape = input_x->shape()->shape();
|
||||
ShapeVector x_min_shape = input_x->shape()->min_shape();
|
||||
ShapeVector x_max_shape = input_x->shape()->max_shape();
|
||||
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
|
||||
CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
|
||||
CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape);
|
||||
CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape);
|
||||
AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(input_w);
|
||||
MS_EXCEPTION_IF_NULL(input_w->shape());
|
||||
ShapeVector w_shape = input_w->shape()->shape();
|
||||
CheckShape(op_name, w_shape, input_w);
|
||||
const uint64_t n_axis = 0;
|
||||
uint64_t c_axis = 1;
|
||||
uint64_t h_axis = 2;
|
||||
uint64_t w_axis = 3;
|
||||
|
||||
int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
|
||||
if (data_format == Format::NHWC) {
|
||||
c_axis = 3;
|
||||
h_axis = 1;
|
||||
w_axis = 2;
|
||||
}
|
||||
int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
|
||||
if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
|
||||
((x_shape[c_axis] / group) != w_shape[c_axis])) {
|
||||
MS_LOG(EXCEPTION) << "x_shape[C_in] / group must be equal to w_shape[C_in]: " << w_shape[c_axis] << ", but got "
|
||||
<< (x_shape[c_axis] / group);
|
||||
}
|
||||
|
||||
int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
|
||||
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
|
||||
MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must be equal to " << out_channel;
|
||||
}
|
||||
|
||||
const size_t kernel_size_num_element = 2;
|
||||
std::vector<int64_t> kernel_size =
|
||||
CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, kernel_size_num_element);
|
||||
if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
|
||||
MS_LOG(EXCEPTION) << "weight height: " << w_shape[h_axis] << " must be equal to " << kernel_size[0];
|
||||
}
|
||||
if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
|
||||
MS_LOG(EXCEPTION) << "weight width: " << w_shape[w_axis] << " must be equal to " << kernel_size[1];
|
||||
}
|
||||
|
||||
std::vector<int64_t> stride =
|
||||
CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), stride_start_idx, stride_num_element);
|
||||
std::vector<int64_t> dilation =
|
||||
CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), dilation_start_idx, dilation_num_element);
|
||||
std::vector<int64_t> padding =
|
||||
CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element);
|
||||
int64_t pad_mode;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
|
||||
std::vector<int64_t> output_hw;
|
||||
std::vector<int64_t> pad_list;
|
||||
std::vector<int64_t> output_hw_min;
|
||||
std::vector<int64_t> pad_list_min;
|
||||
std::vector<int64_t> output_hw_max;
|
||||
std::vector<int64_t> pad_list_max;
|
||||
Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
|
||||
padding);
|
||||
if (x_shape[h_axis] == Shape::SHP_ANY) {
|
||||
output_hw[0] = Shape::SHP_ANY;
|
||||
}
|
||||
if (x_shape[w_axis] == Shape::SHP_ANY) {
|
||||
output_hw[1] = Shape::SHP_ANY;
|
||||
}
|
||||
Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
|
||||
dilation, pad_mode, padding);
|
||||
Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
|
||||
dilation, pad_mode, padding);
|
||||
std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
|
||||
MakeValue(pad_list[3])};
|
||||
primitive->set_attr("pad_list", MakeValue(pad_list_val));
|
||||
|
||||
ShapeVector output_shape;
|
||||
ShapeVector output_shape_min;
|
||||
ShapeVector output_shape_max;
|
||||
if (data_format == Format::NHWC) {
|
||||
output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
|
||||
output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
|
||||
output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
|
||||
} else {
|
||||
output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
|
||||
output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
|
||||
output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
|
||||
}
|
||||
CheckShapeAnyAndPositive(op_name + " output_shape", output_shape);
|
||||
CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min);
|
||||
CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max);
|
||||
|
||||
TypePtr x_type = input_x->element()->GetTypeTrack();
|
||||
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
|
||||
x_type = kInt32;
|
||||
}
|
||||
ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
|
||||
return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: at least one tensor(y_backprop)
|
||||
|
|
|
@ -186,7 +186,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimPooling, R{InferImplPooling, nullptr, true}},
|
||||
{prim::kPrimPoolingGrad, R{InferImplPoolingGrad, nullptr, true}},
|
||||
{prim::kPrimBatchNorm, R{InferImplBatchNorm, nullptr, true}},
|
||||
{prim::kPrimConv2D, R{InferImplConv2D, nullptr, true}},
|
||||
{prim::kPrimBpropCut, R{InferImplBpropCut, nullptr, true}},
|
||||
{prim::kPrimDropout, R{InferImplDropout, nullptr, true}},
|
||||
{prim::kPrimSparseApplyFtrl, R{InferImplSparseApplyFtrl, nullptr, true}},
|
||||
|
|
|
@ -183,83 +183,6 @@ AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
|
|||
return spec->Clone();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Join all types in args_type_list;
|
||||
TypePtr TypeJoin(const TypePtrList &args_type_list) {
|
||||
if (args_type_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_type_list is empty";
|
||||
}
|
||||
|
||||
TypePtr type_tmp = args_type_list[0];
|
||||
for (std::size_t i = 1; i < args_type_list.size(); i++) {
|
||||
type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
|
||||
}
|
||||
return type_tmp;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
|
||||
// As x and predicate both are mindspore type statically, here we only to judge whether
|
||||
// x is predicate or is a subclass of predicate.
|
||||
return IsIdentidityOrSubclass(x, expected_type);
|
||||
}
|
||||
|
||||
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
|
||||
MS_EXCEPTION_IF_NULL(predicate);
|
||||
for (const auto &arg_type : args_type_list) {
|
||||
MS_EXCEPTION_IF_NULL(arg_type);
|
||||
if (!CheckType(predicate, arg_type)) {
|
||||
MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
|
||||
}
|
||||
}
|
||||
return TypeJoin(args_type_list);
|
||||
}
|
||||
|
||||
int64_t GetPositiveAxis(int64_t axis_value, size_t increment) {
|
||||
if (axis_value < 0) {
|
||||
axis_value = axis_value + SizeToLong(increment);
|
||||
}
|
||||
|
||||
if (axis_value < 0) {
|
||||
MS_LOG(EXCEPTION) << "axis_value should not still <0";
|
||||
}
|
||||
|
||||
return axis_value;
|
||||
}
|
||||
|
||||
// Return if two shapes can be broadcast.
|
||||
// Broadcast shape is placed in broadcast_output_shape.
|
||||
ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) {
|
||||
std::reverse(x_shape.begin(), x_shape.end());
|
||||
std::reverse(y_shape.begin(), y_shape.end());
|
||||
// Fill a placeholder value 1 which will be replaced later.
|
||||
size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
|
||||
y_shape.resize(std_len, 1);
|
||||
x_shape.resize(std_len, 1);
|
||||
|
||||
ShapeVector broadcast_shape;
|
||||
for (size_t i = 0; i < std_len; i++) {
|
||||
int64_t x_i = x_shape[i]; // i-th dimension of x
|
||||
int64_t y_i = y_shape[i]; // i-th dimension of y
|
||||
int64_t output_i = 0; // i-th dimension of the output
|
||||
if (x_i == y_i) {
|
||||
output_i = x_i;
|
||||
} else if (x_i == 1) {
|
||||
output_i = y_i;
|
||||
} else if (y_i == 1) {
|
||||
output_i = x_i;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< op
|
||||
<< " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting "
|
||||
"requirements";
|
||||
}
|
||||
broadcast_shape.push_back(output_i);
|
||||
}
|
||||
std::reverse(broadcast_shape.begin(), broadcast_shape.end());
|
||||
return broadcast_shape;
|
||||
}
|
||||
|
||||
ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
|
||||
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
|
||||
if (dlen < 0) {
|
||||
|
|
|
@ -43,20 +43,11 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
|
|||
// else self.Clone;
|
||||
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec);
|
||||
|
||||
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list);
|
||||
|
||||
bool CheckType(const TypePtr &expected_type, const TypePtr &x);
|
||||
|
||||
int64_t GetPositiveAxis(int64_t axis_value, size_t increment);
|
||||
|
||||
ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy);
|
||||
|
||||
MS_CORE_API size_t TypeIdSize(const TypeId data_type);
|
||||
size_t ShapeSize(const std::vector<size_t> &shape);
|
||||
|
||||
// Get broadcasted shape for binary element-wise operation
|
||||
ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y);
|
||||
|
||||
// Check dynamic shape routine
|
||||
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
|
||||
|
||||
|
|
|
@ -61,11 +61,10 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
|
||||
// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
|
||||
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
|
||||
int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank) - 1);
|
||||
int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank));
|
||||
|
||||
ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
|
||||
int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank) - 1);
|
||||
begin_params_axis = abstract::GetPositiveAxis(begin_params_axis, input_rank);
|
||||
int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank));
|
||||
|
||||
// the beta and gama shape should be x_shape[begin_params_axis:]
|
||||
auto valid_types = {kFloat16, kFloat32};
|
||||
|
|
|
@ -104,7 +104,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
|
|||
engine_->Run(tupleSliceGraphPtr, args_spec_list);
|
||||
FAIL() << "Excepted exception :Args type is wrong";
|
||||
} catch (std::runtime_error const &err) {
|
||||
ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos);
|
||||
ASSERT_TRUE(std::string(err.what()).find("TupleSlice input arguments size should be 2, but got 3") !=
|
||||
std::string::npos);
|
||||
} catch (...) {
|
||||
FAIL() << "Excepted exception :Args type is wrong";
|
||||
}
|
||||
|
@ -250,7 +251,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
|
|||
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
|
||||
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
|
||||
|
||||
auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
|
||||
auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
|
||||
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||
AbstractBasePtrList eles;
|
||||
for (size_t i = 0; i < 6; i++) {
|
||||
|
|
Loading…
Reference in New Issue