!18825 [ME][Compiler]Limite broaden when type is number
Merge pull request !18825 from chenfei_mindspore/namespace-err
This commit is contained in:
commit
476c268b13
|
@ -338,9 +338,8 @@ void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList
|
|||
auto &arg_value = op_args[i];
|
||||
auto arg_abs = arg_value->ToAbstract();
|
||||
if (!is_const_prim && !is_const_input) {
|
||||
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
||||
arg_abs = arg_abs->Broaden(config);
|
||||
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
|
||||
arg_abs = arg_abs->Broaden();
|
||||
MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
|
||||
}
|
||||
(*abs_list).emplace_back(arg_abs);
|
||||
}
|
||||
|
@ -353,8 +352,7 @@ abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr
|
|||
}
|
||||
abstract::AbstractBasePtrList new_abs_list(abs_list);
|
||||
auto out_abs = out->ToAbstract();
|
||||
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
||||
out_abs = out_abs->Broaden(config);
|
||||
out_abs = out_abs->Broaden();
|
||||
new_abs_list.emplace_back(out_abs);
|
||||
new_abs_list.emplace_back(out_abs);
|
||||
return new_abs_list;
|
||||
|
|
|
@ -263,12 +263,7 @@ void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList
|
|||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
|
||||
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
// Only broaden scalar that data type is number, such as float16,int32 and so on.
|
||||
auto type = arg->BuildType()->type_id();
|
||||
if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
|
||||
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
|
||||
return arg->Broaden(config);
|
||||
} else if (arg->GetValueTrack() != kAnyValue) {
|
||||
if (arg->GetValueTrack() != kAnyValue) {
|
||||
return arg->Broaden();
|
||||
}
|
||||
return arg;
|
||||
|
@ -280,7 +275,6 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
|
|||
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
AbstractBasePtrList broaded_list;
|
||||
BroadenArgs(args_spec_list, &broaded_list);
|
||||
|
||||
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
|
||||
<< ", broaded: " << mindspore::ToString(broaded_list);
|
||||
return broaded_list;
|
||||
|
|
|
@ -873,10 +873,9 @@ void ForwardExecutor::GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info,
|
|||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
abs = input_value->ToAbstract();
|
||||
if (!is_const_prim && !is_const_input) {
|
||||
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
abs = abs->Broaden(config);
|
||||
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
|
||||
abs = abs->Broaden();
|
||||
MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
|
||||
node_abs_map_[id] = abs;
|
||||
}
|
||||
}
|
||||
|
@ -1808,8 +1807,7 @@ std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args
|
|||
cell_id += it->second->BuildType()->ToString();
|
||||
} else {
|
||||
auto abs = PyAttrValue(args[i])->ToAbstract();
|
||||
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
||||
abs = abs->Broaden(config);
|
||||
abs = abs->Broaden();
|
||||
forward()->node_abs_map()[arg_id] = abs;
|
||||
cell_id += "_" + abs->BuildShape()->ToString();
|
||||
cell_id += abs->BuildType()->ToString();
|
||||
|
|
|
@ -133,13 +133,10 @@ ValuePtr AbstractBase::BuildValue() const {
|
|||
return value_;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractBase::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractBase::Broaden() const {
|
||||
AbstractBasePtr clone = Clone();
|
||||
MS_EXCEPTION_IF_NULL(clone);
|
||||
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
|
||||
if (not_broaden == 0) {
|
||||
clone->set_value(kAnyValue);
|
||||
}
|
||||
clone->set_value(kAnyValue);
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
@ -156,18 +153,17 @@ std::string AbstractBase::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractScalar::Broaden() const {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) {
|
||||
return AbstractBase::Broaden(config);
|
||||
} else {
|
||||
auto type = this->BuildType()->type_id();
|
||||
if (type < kNumberTypeBegin || type > kNumberTypeEnd) {
|
||||
return AbstractBase::Broaden(config);
|
||||
}
|
||||
return Clone();
|
||||
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
|
||||
return AbstractBase::Broaden();
|
||||
}
|
||||
auto type_id = GetTypeTrack()->type_id();
|
||||
if (type_id == kObjectTypeEnvType) {
|
||||
return AbstractBase::Broaden();
|
||||
}
|
||||
return Clone();
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
||||
|
@ -311,11 +307,11 @@ AbstractBasePtrList AbstractSequeue::ElementsClone() const {
|
|||
return ele_list;
|
||||
}
|
||||
|
||||
AbstractBasePtrList AbstractSequeue::ElementsBroaden(uint8_t config) const {
|
||||
AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
|
||||
AbstractBasePtrList ele_list;
|
||||
for (const auto &ele : elements_) {
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
AbstractBasePtr broadend = ele->Broaden(config);
|
||||
AbstractBasePtr broadend = ele->Broaden();
|
||||
ele_list.push_back(broadend);
|
||||
}
|
||||
return ele_list;
|
||||
|
@ -457,13 +453,13 @@ AbstractBasePtr AbstractSlice::Clone() const {
|
|||
return std::make_shared<AbstractSlice>(start, stop, step);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractSlice::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractSlice::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(start_);
|
||||
MS_EXCEPTION_IF_NULL(stop_);
|
||||
MS_EXCEPTION_IF_NULL(step_);
|
||||
AbstractBasePtr start = start_->Broaden(config);
|
||||
AbstractBasePtr stop = stop_->Broaden(config);
|
||||
AbstractBasePtr step = step_->Broaden(config);
|
||||
AbstractBasePtr start = start_->Broaden();
|
||||
AbstractBasePtr stop = stop_->Broaden();
|
||||
AbstractBasePtr step = step_->Broaden();
|
||||
return std::make_shared<AbstractSlice>(start, stop, step);
|
||||
}
|
||||
|
||||
|
@ -508,6 +504,14 @@ ShapePtr AbstractUndetermined::shape() const {
|
|||
return shp;
|
||||
}
|
||||
|
||||
void AbstractUndetermined::set_shape(const BaseShapePtr &shape) {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->isa<NoShape>()) {
|
||||
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
|
||||
}
|
||||
AbstractBase::set_shape(shape);
|
||||
}
|
||||
|
||||
TypePtr AbstractTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element_);
|
||||
TypePtr element_type = element_->BuildType();
|
||||
|
@ -606,16 +610,13 @@ AbstractBasePtr AbstractTensor::Clone() const {
|
|||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractTensor::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractTensor::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element_);
|
||||
auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
|
||||
auto shp = shape();
|
||||
MS_EXCEPTION_IF_NULL(shp);
|
||||
broaden->set_shape(shp->Clone());
|
||||
auto not_broaden = config & kBroadenParameterOnly;
|
||||
if (not_broaden == 0) {
|
||||
broaden->set_value(kAnyValue);
|
||||
}
|
||||
broaden->set_value(kAnyValue);
|
||||
return broaden;
|
||||
}
|
||||
|
||||
|
@ -692,12 +693,12 @@ AbstractBasePtr AbstractDictionary::Clone() const {
|
|||
return std::make_shared<AbstractDictionary>(kv);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractDictionary::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractDictionary::Broaden() const {
|
||||
std::vector<AbstractAttribute> kv;
|
||||
(void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
|
||||
[config](const AbstractAttribute &item) {
|
||||
[](const AbstractAttribute &item) {
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
return std::make_pair(item.first, item.second->Broaden(config));
|
||||
return std::make_pair(item.first, item.second->Broaden());
|
||||
});
|
||||
return std::make_shared<AbstractDictionary>(kv);
|
||||
}
|
||||
|
@ -818,11 +819,11 @@ AbstractBasePtr AbstractClass::Clone() const {
|
|||
return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractClass::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractClass::Broaden() const {
|
||||
std::vector<AbstractAttribute> attributes_clone;
|
||||
for (const auto &attr : attributes_) {
|
||||
MS_EXCEPTION_IF_NULL(attr.second);
|
||||
AbstractBasePtr clone = attr.second->Broaden(config);
|
||||
AbstractBasePtr clone = attr.second->Broaden();
|
||||
AbstractAttribute elem(attr.first, clone);
|
||||
attributes_clone.push_back(elem);
|
||||
}
|
||||
|
@ -1022,15 +1023,6 @@ std::string AbstractNone::ToString() const {
|
|||
|
||||
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
|
||||
|
||||
AbstractBasePtr AbstractRefKey::Broaden(uint8_t config) const {
|
||||
auto refkey = std::make_shared<AbstractRefKey>();
|
||||
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
|
||||
if (not_broaden == 0) {
|
||||
refkey->set_value(kAnyValue);
|
||||
}
|
||||
return refkey;
|
||||
}
|
||||
|
||||
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
|
||||
ValuePtr value_self = GetValueTrack();
|
||||
ValuePtr value_other = other.GetValueTrack();
|
||||
|
@ -1141,9 +1133,9 @@ AbstractBasePtr AbstractKeywordArg::Clone() const {
|
|||
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractKeywordArg::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractKeywordArg::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(arg_value_);
|
||||
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden(config));
|
||||
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
|
||||
}
|
||||
|
||||
std::size_t AbstractKeywordArg::hash() const {
|
||||
|
@ -1259,7 +1251,7 @@ AbstractBasePtr AbstractRowTensor::Clone() const {
|
|||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractRowTensor::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractRowTensor::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
|
||||
auto shp = shape();
|
||||
|
@ -1351,7 +1343,7 @@ AbstractBasePtr AbstractSparseTensor::Clone() const {
|
|||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractSparseTensor::Broaden(uint8_t config) const {
|
||||
AbstractBasePtr AbstractSparseTensor::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
|
||||
auto shp = shape();
|
||||
|
|
|
@ -59,7 +59,7 @@ class AbstractBase : public Base {
|
|||
virtual bool operator==(const AbstractBase &other) const;
|
||||
void set_value(const ValuePtr &value) { value_ = value; }
|
||||
void set_type(const TypePtr &type) { type_ = type; }
|
||||
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
|
||||
virtual void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
|
||||
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
|
||||
const std::string &value_desc() const { return value_desc_; }
|
||||
ValuePtr GetValueTrack() const { return value_; }
|
||||
|
@ -79,16 +79,7 @@ class AbstractBase : public Base {
|
|||
}
|
||||
|
||||
inline static TraceNodeProvider trace_node_provider_ = nullptr;
|
||||
// mask for Broaden config
|
||||
inline static const uint8_t kBroadenTensorOnly = 1;
|
||||
inline static const uint8_t kBroadenParameterOnly = 2;
|
||||
// Scalar as Parameter, should boarden
|
||||
inline static const uint8_t kBroadenScalarParameterOnly = 4;
|
||||
// Each bit for on config.
|
||||
// 00000001 -> 1: only boarden tensor
|
||||
// 00000010 -> 2: only boarden parameter
|
||||
// 00000100 -> 4: only boarden scalar parameter
|
||||
virtual AbstractBasePtr Broaden(uint8_t config = 0) const;
|
||||
virtual AbstractBasePtr Broaden() const;
|
||||
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); }
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) {
|
||||
|
@ -128,7 +119,7 @@ class AbstractScalar : public AbstractBase {
|
|||
AbstractBasePtr Clone() const override {
|
||||
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
|
||||
}
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
};
|
||||
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>;
|
||||
|
@ -148,7 +139,7 @@ class AbstractType : public AbstractBase {
|
|||
|
||||
TypePtr BuildType() const override { return std::make_shared<TypeType>(); }
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
|
||||
AbstractBasePtr Broaden() const override { return Clone(); }
|
||||
};
|
||||
using AbstractTypePtr = std::shared_ptr<AbstractType>;
|
||||
|
||||
|
@ -163,7 +154,7 @@ class AbstractError : public AbstractBase {
|
|||
MS_DECLARE_PARENT(AbstractError, AbstractBase)
|
||||
|
||||
TypePtr BuildType() const override { return std::make_shared<Problem>(); }
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
|
||||
AbstractBasePtr Broaden() const override { return Clone(); }
|
||||
|
||||
AbstractBasePtr Clone() const override {
|
||||
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_);
|
||||
|
@ -200,7 +191,7 @@ class AbstractFunction : public AbstractBase {
|
|||
TypePtr BuildType() const override { return std::make_shared<Function>(); }
|
||||
AbstractBasePtr Clone() const override { return Copy(); }
|
||||
// For Function, no need to broaden.
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
||||
AbstractBasePtr Broaden() const override {
|
||||
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
|
||||
}
|
||||
virtual AbstractFunctionPtr Copy() const = 0;
|
||||
|
@ -229,7 +220,7 @@ class AbstractKeywordArg : public AbstractBase {
|
|||
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
std::size_t hash() const override;
|
||||
|
||||
bool operator==(const AbstractKeywordArg &other) const;
|
||||
|
@ -261,28 +252,37 @@ class AbstractUndetermined : public AbstractBase {
|
|||
if (element->isa<AbstractUndetermined>()) {
|
||||
MS_LOG(EXCEPTION) << "element type error";
|
||||
}
|
||||
set_shape(shape);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->isa<NoShape>()) {
|
||||
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
|
||||
}
|
||||
AbstractBase::set_shape(shape);
|
||||
}
|
||||
AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape)
|
||||
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
|
||||
if (element_type == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "element_type is nullptr";
|
||||
}
|
||||
set_shape(std::make_shared<Shape>(shape));
|
||||
AbstractBase::set_shape(std::make_shared<Shape>(shape));
|
||||
}
|
||||
explicit AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
|
||||
if (element_type == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "element_type is nullptr";
|
||||
}
|
||||
set_shape(shape);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->isa<NoShape>()) {
|
||||
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
|
||||
}
|
||||
AbstractBase::set_shape(shape);
|
||||
}
|
||||
~AbstractUndetermined() override = default;
|
||||
MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
|
||||
TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }
|
||||
const AbstractBasePtr element() const { return element_; }
|
||||
AbstractBasePtr element() const { return element_; }
|
||||
ShapePtr shape() const;
|
||||
void set_shape(const BaseShapePtr &shape) override;
|
||||
|
||||
protected:
|
||||
AbstractBasePtr element_;
|
||||
|
@ -310,7 +310,7 @@ class AbstractTensor : public AbstractUndetermined {
|
|||
TypePtr BuildType() const override;
|
||||
BaseShapePtr BuildShape() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
bool operator==(const AbstractTensor &other) const;
|
||||
|
@ -345,7 +345,7 @@ class AbstractSequeue : public AbstractBase {
|
|||
TypePtrList ElementsType() const;
|
||||
BaseShapePtrList ElementsShape() const;
|
||||
AbstractBasePtrList ElementsClone() const;
|
||||
AbstractBasePtrList ElementsBroaden(uint8_t config = 0) const;
|
||||
AbstractBasePtrList ElementsBroaden() const;
|
||||
|
||||
template <typename T>
|
||||
ValuePtr ElementsBuildValue() const;
|
||||
|
@ -379,9 +379,7 @@ class AbstractTuple : public AbstractSequeue {
|
|||
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); }
|
||||
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
||||
return std::make_shared<AbstractTuple>(ElementsBroaden(config));
|
||||
}
|
||||
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); }
|
||||
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); }
|
||||
|
||||
|
@ -408,9 +406,7 @@ class AbstractList : public AbstractSequeue {
|
|||
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); }
|
||||
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
||||
return std::make_shared<AbstractList>(ElementsBroaden(config));
|
||||
}
|
||||
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); }
|
||||
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); }
|
||||
|
||||
|
@ -442,7 +438,7 @@ class AbstractClass : public AbstractBase {
|
|||
AbstractBasePtr GetAttribute(const std::string &name);
|
||||
ValuePtr GetMethod(const std::string &name);
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
std::string ToString() const override;
|
||||
Named tag() const { return tag_; }
|
||||
std::size_t hash() const override;
|
||||
|
@ -467,7 +463,7 @@ class AbstractDictionary : public AbstractBase {
|
|||
bool operator==(const AbstractDictionary &other) const;
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
std::string ToString() const override;
|
||||
std::size_t hash() const override;
|
||||
std::size_t size() const { return key_values_.size(); }
|
||||
|
@ -491,7 +487,7 @@ class AbstractSlice : public AbstractBase {
|
|||
bool operator==(const AbstractSlice &other) const;
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
std::string ToString() const override;
|
||||
std::size_t hash() const override;
|
||||
AbstractBasePtr start() const { return start_; }
|
||||
|
@ -517,9 +513,7 @@ class AbstractJTagged : public AbstractBase {
|
|||
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); }
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
||||
return std::make_shared<AbstractJTagged>(element_->Broaden(config));
|
||||
}
|
||||
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
|
||||
bool operator==(const AbstractJTagged &other) const;
|
||||
|
@ -616,7 +610,6 @@ class AbstractRefKey : public AbstractBase {
|
|||
}
|
||||
RefKeyPtr ref_key_value() const { return ref_key_value_; }
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
|
@ -628,7 +621,6 @@ using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
|
|||
class AbstractRef : public AbstractTensor {
|
||||
public:
|
||||
AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
|
||||
|
||||
~AbstractRef() override = default;
|
||||
MS_DECLARE_PARENT(AbstractRef, AbstractTensor)
|
||||
|
||||
|
@ -647,13 +639,13 @@ class AbstractRef : public AbstractTensor {
|
|||
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
|
||||
inline AbstractBasePtr ref_key() const { return ref_key_; }
|
||||
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
||||
AbstractBasePtr Broaden() const override {
|
||||
// always broaden for ref
|
||||
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
|
||||
if (abs_tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor);
|
||||
return std::make_shared<AbstractRef>(ref_key_->Broaden(), abs_tensor);
|
||||
}
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
std::size_t hash() const override {
|
||||
|
@ -696,7 +688,7 @@ class AbstractRowTensor : public AbstractUndetermined {
|
|||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
@ -725,7 +717,7 @@ class AbstractSparseTensor : public AbstractUndetermined {
|
|||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
@ -743,7 +735,7 @@ class AbstractMonad : public AbstractBase {
|
|||
|
||||
std::size_t hash() const override { return hash_combine({tid()}); }
|
||||
TypePtr BuildType() const override { return GetTypeTrack(); }
|
||||
AbstractBasePtr Broaden(uint8_t config) const override { return AbstractBase::Broaden(config); }
|
||||
AbstractBasePtr Broaden() const override { return AbstractBase::Broaden(); }
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override = 0;
|
||||
std::string ToString() const override {
|
||||
std::ostringstream buffer;
|
||||
|
|
Loading…
Reference in New Issue