forked from mindspore-Ecosystem/mindspore
modify the error message of abstract join
This commit is contained in:
parent
4d984fab6a
commit
d695554ae5
|
@ -35,6 +35,7 @@
|
|||
#include "debug/common.h"
|
||||
#include "pipeline/jit/static_analysis/evaluator.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support debug trace information
|
||||
|
@ -605,5 +606,20 @@ struct TraceProviderRegister {
|
|||
}
|
||||
~TraceProviderRegister() = default;
|
||||
} trace_provider_regsiter;
|
||||
|
||||
// Register trace cnode provider to AbstractBase.
|
||||
struct TraceNodeProviderRegister {
|
||||
TraceNodeProviderRegister() {
|
||||
abstract::AbstractBase::set_trace_node_provider([](AnfNodePtr *node) {
|
||||
auto stack = GetCNodeDebugStack();
|
||||
if (!stack.empty()) {
|
||||
auto conf = stack.back();
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
*node = conf->node();
|
||||
}
|
||||
});
|
||||
}
|
||||
~TraceNodeProviderRegister() = default;
|
||||
} trace_node_provider_regsiter;
|
||||
} // namespace trace
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,9 +49,7 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
|
|||
|
||||
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
|
||||
if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
|
||||
auto abstract = arg1->Join(arg2);
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
return abstract;
|
||||
return arg1->Join(arg2);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -644,7 +642,39 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
|
|||
return latest_entry;
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) {
|
||||
std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBasePtr &last_spec,
|
||||
const AnfNodePtr &node, const std::string &error_info) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "The return values of different branches do not match. " << error_info
|
||||
<< " The abstract type of the return value of the current branch is " << spec->ToString()
|
||||
<< ", and that of the previous branch is " << last_spec->ToString() << ". Please check the node "
|
||||
<< node->DebugString();
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>()->input(0);
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
|
||||
// {prim::kPrimSwitch, cond, true_branch, false_branch}
|
||||
constexpr int true_index = 2;
|
||||
constexpr int false_index = 3;
|
||||
auto inputs = cnode->cast<CNodePtr>()->inputs();
|
||||
buffer << ", true branch: " << inputs.at(true_index)->ToString()
|
||||
<< ", false branch: " << inputs.at(false_index)->ToString();
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
|
||||
// {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
|
||||
constexpr int branch_index = 2;
|
||||
auto tuple_node = cnode->cast<CNodePtr>()->input(branch_index);
|
||||
if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
|
||||
auto tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 1; i < tuple_inputs.size(); i++) {
|
||||
buffer << ", branch" << i << ": " << tuple_inputs.at(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buffer << ". trace: " << trace::DumpSourceLines(node);
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) {
|
||||
if (out_specs.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
|
||||
}
|
||||
|
@ -654,8 +684,28 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
|||
// If only one result derived, then broaden it to avoid wrong constant propagation.
|
||||
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
auto joined_spec = AbstractJoin(out_specs);
|
||||
MS_EXCEPTION_IF_NULL(joined_spec);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
AbstractBasePtr last_spec = out_specs[0];
|
||||
AbstractBasePtr joined_spec = out_specs[0];
|
||||
for (const auto &spec : out_specs) {
|
||||
MS_EXCEPTION_IF_NULL(spec);
|
||||
try {
|
||||
joined_spec = joined_spec->Join(spec);
|
||||
} catch (const py::type_error &ex) {
|
||||
auto error_info = ExtractLoggingInfo(ex.what());
|
||||
MS_EXCEPTION(TypeError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
|
||||
} catch (const py::value_error &ex) {
|
||||
auto error_info = ExtractLoggingInfo(ex.what());
|
||||
MS_EXCEPTION(ValueError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
|
||||
} catch (const std::exception &ex) {
|
||||
auto error_info = ExtractLoggingInfo(ex.what());
|
||||
MS_LOG(EXCEPTION) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(joined_spec);
|
||||
last_spec = spec;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
||||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
@ -664,8 +714,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
AbstractBasePtrList out_specs;
|
||||
EvaluatorPtr last_eval = nullptr;
|
||||
AbstractBasePtr last_abstract = nullptr;
|
||||
const size_t evaluators_size = 2;
|
||||
if (evaluators.size() < evaluators_size) {
|
||||
MS_LOG(ERROR) << "evaluators size is less than 2";
|
||||
|
@ -692,17 +740,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
auto eval_abstract = eval_result->abstract();
|
||||
MS_EXCEPTION_IF_NULL(eval_abstract);
|
||||
|
||||
if (last_abstract != nullptr && eval_abstract->Join(last_abstract) == nullptr) {
|
||||
auto node = out_conf->node();
|
||||
MS_LOG(EXCEPTION) << "Abstracts cannot be joined! Please check the data type of node : " << node->DebugString()
|
||||
<< ".\nThe current evaluator is " << eval->ToString() << " with abstract "
|
||||
<< eval_abstract->ToString() << ", and the previous evaluator is " << last_eval->ToString()
|
||||
<< " with abstract " << last_abstract->ToString() << trace::DumpSourceLines(node);
|
||||
} else {
|
||||
last_abstract = eval_abstract;
|
||||
last_eval = eval;
|
||||
}
|
||||
|
||||
out_specs.push_back(eval_abstract);
|
||||
eval_trace_.pop_back();
|
||||
if (eval_trace_.empty()) {
|
||||
|
@ -729,7 +766,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
}
|
||||
}
|
||||
|
||||
return ProcessEvalResults(out_specs);
|
||||
return ProcessEvalResults(out_specs, out_conf->node());
|
||||
}
|
||||
|
||||
EvalResultPtr AnfNodeConfig::ObtainEvalResult() {
|
||||
|
|
|
@ -316,7 +316,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
|
||||
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
|
||||
bool *continue_flag);
|
||||
EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs);
|
||||
EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
|
||||
|
||||
const PrimEvaluatorMap &prim_constructors_;
|
||||
FuncGraphManagerPtr func_graph_manager_;
|
||||
|
|
|
@ -18,14 +18,73 @@
|
|||
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
|
||||
#include "utils/symbolic.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AnfNodePtr GetTraceNode(const AbstractBasePtr &abs) {
|
||||
AnfNodePtr node = nullptr;
|
||||
if (abs->trace_node_provider_ != nullptr) {
|
||||
abs->trace_node_provider_(&node);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
inline void AbstractTypeJoinLogging(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
|
||||
std::ostringstream oss;
|
||||
oss << "Type Join Failed: abstract type " << abstract1->type_name() << " cannot not join with "
|
||||
<< abstract2->type_name() << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
|
||||
<< "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
|
||||
auto node = GetTraceNode(abstract1);
|
||||
if (node != nullptr) {
|
||||
oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << oss.str();
|
||||
}
|
||||
|
||||
inline void TypeJoinLogging(const TypePtr &type1, const TypePtr &type2, const AbstractBasePtr &abstract1,
|
||||
const AbstractBasePtr &abstract2) {
|
||||
std::ostringstream oss;
|
||||
oss << "Type Join Failed: dtype1 = " << type1->ToString() << ", dtype2 = " << type2->ToString()
|
||||
<< ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
|
||||
<< "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
|
||||
auto node = GetTraceNode(abstract1);
|
||||
if (node != nullptr) {
|
||||
oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << oss.str();
|
||||
}
|
||||
|
||||
inline void ShapeJoinLogging(const BaseShapePtr &shape1, const BaseShapePtr &shape2, const AbstractBasePtr &abstract1,
|
||||
const AbstractBasePtr &abstract2) {
|
||||
std::ostringstream oss;
|
||||
oss << "Shape Join Failed: shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString()
|
||||
<< ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
|
||||
<< "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
|
||||
auto node = GetTraceNode(abstract1);
|
||||
if (node != nullptr) {
|
||||
oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
MS_EXCEPTION(ValueError) << oss.str();
|
||||
}
|
||||
|
||||
std::string ExtractLoggingInfo(const std::string &info) {
|
||||
// Extract log information based on the keyword "Type Join Failed" or "Shape Join Failed"
|
||||
std::regex e("(Type Join Failed|Shape Join Failed).*?\\.");
|
||||
std::smatch result;
|
||||
bool found = std::regex_search(info, result, e);
|
||||
if (found) {
|
||||
return result.str();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool AbstractBase::operator==(const AbstractBase &other) const {
|
||||
if (tid() != other.tid()) {
|
||||
return false;
|
||||
|
@ -112,15 +171,14 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
|||
if (*this == *other) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
auto value_self = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_self);
|
||||
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
|
||||
auto value_other = other->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_other);
|
||||
auto type_self = GetTypeTrack();
|
||||
auto type_other = other->GetTypeTrack();
|
||||
TypePtr res_type = TypeJoin(type_self, type_other);
|
||||
if (res_type == kAnyType) {
|
||||
MS_LOG(ERROR) << "Type join failed, type1 = " << value_self->ToString() << ", type2 = " << value_other->ToString();
|
||||
return nullptr;
|
||||
TypeJoinLogging(type_self, type_other, shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
auto value_self = GetValueTrack();
|
||||
auto value_other = other->GetValueTrack();
|
||||
ValuePtr res_value = ValueJoin(value_self, value_other);
|
||||
if (res_value == value_self) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
|
@ -187,7 +245,7 @@ AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
|
|||
MS_EXCEPTION_IF_NULL(other);
|
||||
auto other_func = dyn_cast<AbstractFunction>(other);
|
||||
if (other_func == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
return Join(other_func);
|
||||
}
|
||||
|
@ -281,15 +339,10 @@ AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) {
|
|||
MS_EXCEPTION_IF_NULL(other);
|
||||
auto other_sequeue = dyn_cast<T>(other);
|
||||
if (other_sequeue == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
auto joined_list = AbstractJoin(elements_, other_sequeue->elements_);
|
||||
bool changes = false;
|
||||
if (elements_.size() > joined_list.size()) {
|
||||
MS_EXCEPTION(IndexError) << "Abstract " << ToString() << "'s size is " << elements_.size()
|
||||
<< " but element joined abstract " << other->ToString() << "with the the size "
|
||||
<< joined_list.size();
|
||||
}
|
||||
for (std::size_t i = 0; i < elements_.size(); i++) {
|
||||
if (elements_[i] != joined_list[i]) {
|
||||
changes = true;
|
||||
|
@ -469,29 +522,41 @@ BaseShapePtr AbstractTensor::BuildShape() const {
|
|||
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
auto type = other->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(element_);
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(element_);
|
||||
|
||||
// AbstractTensor join with AbstractUndetermined
|
||||
if (type->type_id() == kObjectTypeUndeterminedType) {
|
||||
auto other_undetermined_tensor = dyn_cast<AbstractUndetermined>(other);
|
||||
MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
|
||||
auto element = element_->Join(other_undetermined_tensor->element());
|
||||
if (element == nullptr) {
|
||||
return nullptr;
|
||||
// check shape
|
||||
auto res_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());
|
||||
if (res_shape == nullptr) {
|
||||
ShapeJoinLogging(shape(), other_undetermined_tensor->shape(), shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
return std::make_shared<AbstractUndetermined>(element, ShapeJoin(shape(), other_undetermined_tensor->shape()));
|
||||
// check element
|
||||
auto element = element_->Join(other_undetermined_tensor->element());
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
return std::make_shared<AbstractUndetermined>(element, res_shape);
|
||||
}
|
||||
|
||||
// AbstractTensor join with AbstractTensor
|
||||
auto other_tensor = dyn_cast<AbstractTensor>(other);
|
||||
if (other_tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
if (*this == *other) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
auto element = element_->Join(other_tensor->element_);
|
||||
if (element == nullptr) {
|
||||
return nullptr;
|
||||
// check shape
|
||||
auto res_shape = ShapeJoin(this->shape(), other_tensor->shape());
|
||||
if (res_shape == nullptr) {
|
||||
ShapeJoinLogging(shape(), other_tensor->shape(), shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
return std::make_shared<AbstractTensor>(element, ShapeJoin(this->shape(), other_tensor->shape()));
|
||||
// check element
|
||||
auto element = element_->Join(other_tensor->element_);
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
return std::make_shared<AbstractTensor>(element, res_shape);
|
||||
}
|
||||
|
||||
bool AbstractTensor::equal_to(const AbstractTensor &other) const {
|
||||
|
@ -826,7 +891,7 @@ AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
|
|||
MS_EXCEPTION_IF_NULL(other);
|
||||
auto other_jtagged = dyn_cast<AbstractJTagged>(other);
|
||||
if (other_jtagged == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
auto joined_elem = element_->Join(other_jtagged->element_);
|
||||
return std::make_shared<AbstractJTagged>(joined_elem);
|
||||
|
@ -1328,15 +1393,14 @@ std::string AbstractSparseTensor::ToString() const {
|
|||
|
||||
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
if (other->isa<AbstractUMonad>()) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
if (!other->isa<AbstractUMonad>()) {
|
||||
auto this_type = GetTypeTrack();
|
||||
auto other_type = other->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(this_type);
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
auto this_type = GetTypeTrack();
|
||||
auto other_type = other->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(this_type);
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << this_type->ToString()
|
||||
<< ", type2 = " << other_type->ToString();
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
|
||||
bool AbstractUMonad::operator==(const AbstractUMonad &) const { return true; }
|
||||
|
@ -1350,15 +1414,14 @@ bool AbstractUMonad::operator==(const AbstractBase &other) const {
|
|||
|
||||
AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
if (other->isa<AbstractIOMonad>()) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
if (!other->isa<AbstractIOMonad>()) {
|
||||
auto this_type = GetTypeTrack();
|
||||
auto other_type = other->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(this_type);
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
|
||||
}
|
||||
auto this_type = GetTypeTrack();
|
||||
auto other_type = other->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(this_type);
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << this_type->ToString()
|
||||
<< ", type2 = " << other_type->ToString();
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
|
||||
bool AbstractIOMonad::operator==(const AbstractIOMonad &) const { return true; }
|
||||
|
|
|
@ -45,6 +45,8 @@ using AbstractBasePtrList = std::vector<AbstractBasePtr>;
|
|||
// to express the type, shape, and value of the real value.
|
||||
class AbstractBase : public Base {
|
||||
public:
|
||||
using TraceNodeProvider = std::function<void(AnfNodePtr *node)>;
|
||||
|
||||
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
|
||||
const BaseShapePtr &shape = kNoShape)
|
||||
: value_(value), type_(type), shape_(shape) {}
|
||||
|
@ -72,6 +74,11 @@ class AbstractBase : public Base {
|
|||
virtual BaseShapePtr BuildShape() const { return kNoShape; }
|
||||
virtual AbstractBasePtr Clone() const = 0;
|
||||
|
||||
static void set_trace_node_provider(TraceNodeProvider trace_node_provider) {
|
||||
trace_node_provider_ = trace_node_provider;
|
||||
}
|
||||
|
||||
inline static TraceNodeProvider trace_node_provider_ = nullptr;
|
||||
// mask for Broaden config
|
||||
inline static const uint8_t kBroadenTensorOnly = 1;
|
||||
inline static const uint8_t kBroadenParameterOnly = 2;
|
||||
|
@ -760,6 +767,8 @@ class AbstractIOMonad : public AbstractMonad {
|
|||
};
|
||||
using AbstractIOMonadPtr = std::shared_ptr<AbstractIOMonad>;
|
||||
|
||||
AnfNodePtr GetTraceNode(const AbstractBasePtr &abs);
|
||||
std::string ExtractLoggingInfo(const std::string &info);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_
|
||||
|
|
|
@ -115,8 +115,7 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
|
|||
if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
|
||||
return shape2;
|
||||
}
|
||||
MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
|
||||
<< ", shape2 = " << shape2->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
ShapeVector dims;
|
||||
bool has_dynamic_shape = false;
|
||||
|
|
|
@ -889,7 +889,7 @@ def test_switch_layer_dtype_join_failed():
|
|||
inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
i = Tensor(0, mstype.int32)
|
||||
|
||||
with pytest.raises(Exception) as err:
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(i, inp)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue