!16497 Add more accurate abstract join error information

From: @huangbingjian
Reviewed-by: @zh_qh, @ginfung
Signed-off-by: @zh_qh

Author: mindspore-ci-bot, 2021-05-18 19:36:11 +08:00 (committed by Gitee)
Commit: ed72a0d9c3
5 changed files with 74 additions and 43 deletions


@@ -240,13 +240,14 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
const OptimizerPtr &optimizer, size_t space) const {
+ constexpr int pad_width = 4;
std::stringstream ss;
ss << std::endl
<< "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
<< std::endl;
for (size_t i = 0; i < list_.size(); i++) {
auto name = list_[i]->name_;
- ss << std::left << std::setw(SizeToInt(space) + 4) << name << "\t";
+ ss << std::left << std::setw(SizeToInt(space) + pad_width) << name << "\t";
for (auto change : status.at(name + std::to_string(i))) {
ss << change << " ";
}


@@ -49,7 +49,9 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
- return arg1->Join(arg2);
+ auto abstract = arg1->Join(arg2);
+ MS_EXCEPTION_IF_NULL(abstract);
+ return abstract;
}
return nullptr;
}
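
As context for the hunk above, here is a minimal standalone C++ sketch (not part of the diff) of the null-propagating join convention this commit moves toward: Join signals failure by returning nullptr instead of throwing deep inside, and each caller either checks the result immediately (as IntermediateJoin now does) or propagates the nullptr upward. The Abstract type and its name-equality join rule below are illustrative stand-ins, not the real MindSpore classes.

#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct Abstract {
  std::string type;
  explicit Abstract(std::string t) : type(std::move(t)) {}
  // Returns nullptr on failure instead of throwing, so the caller decides
  // how much context to attach to the error report.
  std::shared_ptr<Abstract> Join(const std::shared_ptr<Abstract> &other) const {
    if (other == nullptr || other->type != type) {
      std::cerr << "Type join failed, type1 = " << type << ", type2 = "
                << (other ? other->type : std::string("null")) << "\n";
      return nullptr;
    }
    return std::make_shared<Abstract>(type);
  }
};

int main() {
  auto a = std::make_shared<Abstract>("Float32");
  auto b = std::make_shared<Abstract>("Int64");
  auto joined = a->Join(b);  // nullptr: the types differ
  std::cout << (joined ? joined->type : std::string("join failed")) << std::endl;
  return 0;
}
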
@@ -661,6 +663,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
AbstractBasePtrList out_specs;
+ EvaluatorPtr last_eval = nullptr;
+ AbstractBasePtr last_abstract = nullptr;
multi_poss_[evaluators[0]] = evaluators[1];
multi_poss_[evaluators[1]] = evaluators[0];
AbstractBasePtrList args_spec_list;
@@ -680,8 +684,21 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
eval_trace_.push_back(current_inf);
MS_EXCEPTION_IF_NULL(eval);
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
- MS_EXCEPTION_IF_NULL(eval_result->abstract());
- out_specs.push_back(eval_result->abstract());
+ auto eval_abstract = eval_result->abstract();
+ MS_EXCEPTION_IF_NULL(eval_abstract);
+ if (last_abstract != nullptr && eval_abstract->Join(last_abstract) == nullptr) {
+ auto node = out_conf->node();
+ MS_LOG(EXCEPTION) << "Abstracts cannot be joined! Please check the data type of node : " << node->DebugString()
+ << ".\nThe current evaluator is " << eval->ToString() << " with abstract "
+ << eval_abstract->ToString() << ", and the previous evaluator is " << last_eval->ToString()
+ << " with abstract " << last_abstract->ToString() << trace::DumpSourceLines(node);
+ } else {
+ last_abstract = eval_abstract;
+ last_eval = eval;
+ }
+ out_specs.push_back(eval_abstract);
eval_trace_.pop_back();
if (eval_trace_.empty()) {
multi_poss_.clear();
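
The hunk above remembers the previous evaluator and its abstract so that a failed join can name both sides of the mismatch, which is the "more accurate error information" the commit title promises. Below is a compilable sketch of that diagnostic pattern; EvalOutcome, CanJoin, and the message text are toy stand-ins for the real evaluator and join machinery.

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Each candidate evaluator produced an abstract; the toy join rule here is
// plain type-name equality.
struct EvalOutcome {
  std::string evaluator;
  std::string abstract_type;
};

bool CanJoin(const std::string &a, const std::string &b) { return a == b; }

void CheckJoinable(const std::vector<EvalOutcome> &outcomes) {
  const EvalOutcome *last = nullptr;
  for (const auto &cur : outcomes) {
    if (last != nullptr && !CanJoin(cur.abstract_type, last->abstract_type)) {
      // Name both sides of the mismatch, mirroring the extra detail the
      // commit adds to the real MS_LOG(EXCEPTION) message.
      throw std::runtime_error(
          "Abstracts cannot be joined! current evaluator " + cur.evaluator +
          " with abstract " + cur.abstract_type + ", previous evaluator " +
          last->evaluator + " with abstract " + last->abstract_type);
    }
    last = &cur;
  }
}

int main() {
  try {
    CheckJoinable({{"EvaluatorA", "Float32"}, {"EvaluatorB", "Int64"}});
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;  // reports both evaluators
  }
  return 0;
}
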


@@ -106,8 +106,9 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(value_self);
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
if (res_type == kAnyType) {
- MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString()
- << ", type2 = " << other->GetTypeTrack()->ToString();
+ MS_LOG(ERROR) << "Type join failed, type1 = " << GetTypeTrack()->ToString()
+ << ", type2 = " << other->GetTypeTrack()->ToString();
+ return nullptr;
}
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
if (res_value == value_self) {
@@ -460,6 +461,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) {
auto other_tensor = dyn_cast<AbstractUndetermined>(other);
auto element = element_->Join(other_tensor->element());
+ if (element == nullptr) {
+ return nullptr;
+ }
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractUndetermined>(element, shape);
return ret;
@@ -472,6 +476,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
return shared_from_base<AbstractBase>();
}
auto element = element_->Join(other_tensor->element_);
+ if (element == nullptr) {
+ return nullptr;
+ }
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
return std::make_shared<AbstractTensor>(element, shape);
}
@@ -881,6 +888,7 @@ AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
}
auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>();
+ MS_EXCEPTION_IF_NULL(ref);
return std::make_shared<AbstractRef>(ref_key, ref);
}
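
The single guard added to AbstractRef::Join covers the result of AbstractTensor::Join(...)->cast<AbstractTensorPtr>(), which can be null either because the join failed or because the cast did. A small standalone sketch of the same check-before-use guard, with illustrative types in place of the real abstract hierarchy:

#include <memory>
#include <stdexcept>

struct AbstractBase { virtual ~AbstractBase() = default; };
struct AbstractTensorLike : AbstractBase {};

// The joined abstract may be null (join failed) or of an unexpected dynamic
// type (cast failed); either way it must be verified before being wrapped,
// which is what MS_EXCEPTION_IF_NULL(ref) does in the real code.
std::shared_ptr<AbstractTensorLike> RequireTensor(const std::shared_ptr<AbstractBase> &joined) {
  auto tensor = std::dynamic_pointer_cast<AbstractTensorLike>(joined);
  if (tensor == nullptr) {
    throw std::runtime_error("tensor join produced no usable abstract");
  }
  return tensor;
}

int main() {
  auto ok = RequireTensor(std::make_shared<AbstractTensorLike>());  // passes the guard
  (void)ok;
  return 0;
}
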


@@ -50,41 +50,7 @@ TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
return kAnyType;
}
- ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
- MS_EXCEPTION_IF_NULL(shape1);
- MS_EXCEPTION_IF_NULL(shape2);
- if (*shape1 == *shape2) {
- return shape1;
- }
- // lengths of two shapes are not same, join failed
- if (shape1->shape().size() != shape2->shape().size()) {
- // special case: shape(1), shape() -> shape(1)
- if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
- return shape1;
- }
- if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
- return shape2;
- }
- MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
- << ", shape2 = " << shape2->ToString();
- }
- ShapeVector dims;
- bool has_dynamic_shape = false;
- dims.resize(shape1->shape().size());
- for (std::size_t i = 0; i < shape1->shape().size(); i++) {
- if (shape1->shape()[i] == shape2->shape()[i]) {
- dims[i] = shape1->shape()[i];
- if (shape1->shape()[i] == Shape::SHP_ANY) {
- has_dynamic_shape = true;
- }
- } else {
- dims[i] = Shape::SHP_ANY;
- has_dynamic_shape = true;
- }
- }
- if (!has_dynamic_shape) {
- return std::make_shared<Shape>(dims);
- }
+ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, const ShapeVector &dims) {
// calculate dynamic shape
ShapeVector min_dims(dims.size());
ShapeVector max_dims(dims.size());
@@ -131,6 +97,44 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
return std::make_shared<Shape>(dims, min_dims, max_dims);
}
+ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
+ MS_EXCEPTION_IF_NULL(shape1);
+ MS_EXCEPTION_IF_NULL(shape2);
+ if (*shape1 == *shape2) {
+ return shape1;
+ }
+ // lengths of two shapes are not same, join failed
+ if (shape1->shape().size() != shape2->shape().size()) {
+ // special case: shape(1), shape() -> shape(1)
+ if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
+ return shape1;
+ }
+ if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
+ return shape2;
+ }
+ MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
+ << ", shape2 = " << shape2->ToString();
+ }
+ ShapeVector dims;
+ bool has_dynamic_shape = false;
+ dims.resize(shape1->shape().size());
+ for (std::size_t i = 0; i < shape1->shape().size(); i++) {
+ if (shape1->shape()[i] == shape2->shape()[i]) {
+ dims[i] = shape1->shape()[i];
+ if (shape1->shape()[i] == Shape::SHP_ANY) {
+ has_dynamic_shape = true;
+ }
+ } else {
+ dims[i] = Shape::SHP_ANY;
+ has_dynamic_shape = true;
+ }
+ }
+ if (!has_dynamic_shape) {
+ return std::make_shared<Shape>(dims);
+ }
+ return CalculateDynamicShape(shape1, shape2, dims);
+ }
AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() < 1) {
MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size()
@@ -154,6 +158,7 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
bool changes = false;
for (std::size_t i = 0; i < spec1.size(); i++) {
auto joined_elem = spec1[i]->Join(spec2[i]);
+ MS_EXCEPTION_IF_NULL(joined_elem);
if (joined_elem != spec1[i]) {
changes = true;
}
@@ -189,7 +194,7 @@ TypePtr TypeJoin(const TypePtrList &args_type_list) {
} // namespace
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
- // As x and predicate both are mindspore type staticly, here we only to judge whether
+ // As x and predicate both are mindspore type statically, here we only to judge whether
// x is predicate or is a subclass of predicate.
return IsIdentidityOrSubclass(x, expected_type);
}
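
For reference, the per-dimension rule the relocated ShapeJoin applies is small enough to state on its own: equal dimensions are kept, and mismatched dimensions become dynamic. The standalone sketch below captures just that rule; kShapeAny stands in for Shape::SHP_ANY, and the shape(1)-vs-shape() special case and the min/max bookkeeping of CalculateDynamicShape are deliberately omitted.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

constexpr int64_t kShapeAny = -1;  // stand-in for Shape::SHP_ANY

std::vector<int64_t> JoinDims(const std::vector<int64_t> &s1,
                              const std::vector<int64_t> &s2) {
  if (s1.size() != s2.size()) {
    // The real ShapeJoin also accepts shape(1) vs shape() before failing.
    throw std::invalid_argument("unsupported shape join: ranks differ");
  }
  std::vector<int64_t> dims(s1.size());
  for (std::size_t i = 0; i < s1.size(); ++i) {
    dims[i] = (s1[i] == s2[i]) ? s1[i] : kShapeAny;  // mismatch -> dynamic dim
  }
  return dims;
}

int main() {
  auto joined = JoinDims({2, 3, 4}, {2, 5, 4});
  for (auto d : joined) std::cout << d << " ";  // prints: 2 -1 4
  std::cout << std::endl;
  return 0;
}
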


@@ -889,7 +889,7 @@ def test_switch_layer_dtype_join_failed():
inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(0, mstype.int32)
- with pytest.raises(TypeError) as err:
+ with pytest.raises(Exception) as err:
net(i, inp)