!18825 [ME][Compiler]Limite broaden when type is number

Merge pull request !18825 from chenfei_mindspore/namespace-err
This commit is contained in:
i-robot 2021-07-05 01:30:25 +00:00 committed by Gitee
commit 476c268b13
5 changed files with 75 additions and 101 deletions

View File

@ -338,9 +338,8 @@ void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList
auto &arg_value = op_args[i];
auto arg_abs = arg_value->ToAbstract();
if (!is_const_prim && !is_const_input) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
arg_abs = arg_abs->Broaden(config);
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
arg_abs = arg_abs->Broaden();
MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
}
(*abs_list).emplace_back(arg_abs);
}
@ -353,8 +352,7 @@ abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr
}
abstract::AbstractBasePtrList new_abs_list(abs_list);
auto out_abs = out->ToAbstract();
auto config = abstract::AbstractBase::kBroadenTensorOnly;
out_abs = out_abs->Broaden(config);
out_abs = out_abs->Broaden();
new_abs_list.emplace_back(out_abs);
new_abs_list.emplace_back(out_abs);
return new_abs_list;

View File

@ -263,12 +263,7 @@ void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
// Only broaden scalar that data type is number, such as float16,int32 and so on.
auto type = arg->BuildType()->type_id();
if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
return arg->Broaden(config);
} else if (arg->GetValueTrack() != kAnyValue) {
if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
}
return arg;
@ -280,7 +275,6 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
AbstractBasePtrList broaded_list;
BroadenArgs(args_spec_list, &broaded_list);
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
<< ", broaded: " << mindspore::ToString(broaded_list);
return broaded_list;

View File

@ -873,10 +873,9 @@ void ForwardExecutor::GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(input_value);
abs = input_value->ToAbstract();
if (!is_const_prim && !is_const_input) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
MS_EXCEPTION_IF_NULL(abs);
abs = abs->Broaden(config);
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
abs = abs->Broaden();
MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
node_abs_map_[id] = abs;
}
}
@ -1808,8 +1807,7 @@ std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args
cell_id += it->second->BuildType()->ToString();
} else {
auto abs = PyAttrValue(args[i])->ToAbstract();
auto config = abstract::AbstractBase::kBroadenTensorOnly;
abs = abs->Broaden(config);
abs = abs->Broaden();
forward()->node_abs_map()[arg_id] = abs;
cell_id += "_" + abs->BuildShape()->ToString();
cell_id += abs->BuildType()->ToString();

View File

@ -133,13 +133,10 @@ ValuePtr AbstractBase::BuildValue() const {
return value_;
}
AbstractBasePtr AbstractBase::Broaden(uint8_t config) const {
AbstractBasePtr AbstractBase::Broaden() const {
AbstractBasePtr clone = Clone();
MS_EXCEPTION_IF_NULL(clone);
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
if (not_broaden == 0) {
clone->set_value(kAnyValue);
}
clone->set_value(kAnyValue);
return clone;
}
@ -156,18 +153,17 @@ std::string AbstractBase::ToString() const {
return buffer.str();
}
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
AbstractBasePtr AbstractScalar::Broaden() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) {
return AbstractBase::Broaden(config);
} else {
auto type = this->BuildType()->type_id();
if (type < kNumberTypeBegin || type > kNumberTypeEnd) {
return AbstractBase::Broaden(config);
}
return Clone();
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
return AbstractBase::Broaden();
}
auto type_id = GetTypeTrack()->type_id();
if (type_id == kObjectTypeEnvType) {
return AbstractBase::Broaden();
}
return Clone();
}
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
@ -311,11 +307,11 @@ AbstractBasePtrList AbstractSequeue::ElementsClone() const {
return ele_list;
}
AbstractBasePtrList AbstractSequeue::ElementsBroaden(uint8_t config) const {
AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
AbstractBasePtrList ele_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
AbstractBasePtr broadend = ele->Broaden(config);
AbstractBasePtr broadend = ele->Broaden();
ele_list.push_back(broadend);
}
return ele_list;
@ -457,13 +453,13 @@ AbstractBasePtr AbstractSlice::Clone() const {
return std::make_shared<AbstractSlice>(start, stop, step);
}
AbstractBasePtr AbstractSlice::Broaden(uint8_t config) const {
AbstractBasePtr AbstractSlice::Broaden() const {
MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_);
AbstractBasePtr start = start_->Broaden(config);
AbstractBasePtr stop = stop_->Broaden(config);
AbstractBasePtr step = step_->Broaden(config);
AbstractBasePtr start = start_->Broaden();
AbstractBasePtr stop = stop_->Broaden();
AbstractBasePtr step = step_->Broaden();
return std::make_shared<AbstractSlice>(start, stop, step);
}
@ -508,6 +504,14 @@ ShapePtr AbstractUndetermined::shape() const {
return shp;
}
void AbstractUndetermined::set_shape(const BaseShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
if (shape->isa<NoShape>()) {
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
}
AbstractBase::set_shape(shape);
}
TypePtr AbstractTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element_);
TypePtr element_type = element_->BuildType();
@ -606,16 +610,13 @@ AbstractBasePtr AbstractTensor::Clone() const {
return clone;
}
AbstractBasePtr AbstractTensor::Broaden(uint8_t config) const {
AbstractBasePtr AbstractTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element_);
auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
auto shp = shape();
MS_EXCEPTION_IF_NULL(shp);
broaden->set_shape(shp->Clone());
auto not_broaden = config & kBroadenParameterOnly;
if (not_broaden == 0) {
broaden->set_value(kAnyValue);
}
broaden->set_value(kAnyValue);
return broaden;
}
@ -692,12 +693,12 @@ AbstractBasePtr AbstractDictionary::Clone() const {
return std::make_shared<AbstractDictionary>(kv);
}
AbstractBasePtr AbstractDictionary::Broaden(uint8_t config) const {
AbstractBasePtr AbstractDictionary::Broaden() const {
std::vector<AbstractAttribute> kv;
(void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[config](const AbstractAttribute &item) {
[](const AbstractAttribute &item) {
MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first, item.second->Broaden(config));
return std::make_pair(item.first, item.second->Broaden());
});
return std::make_shared<AbstractDictionary>(kv);
}
@ -818,11 +819,11 @@ AbstractBasePtr AbstractClass::Clone() const {
return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
}
AbstractBasePtr AbstractClass::Broaden(uint8_t config) const {
AbstractBasePtr AbstractClass::Broaden() const {
std::vector<AbstractAttribute> attributes_clone;
for (const auto &attr : attributes_) {
MS_EXCEPTION_IF_NULL(attr.second);
AbstractBasePtr clone = attr.second->Broaden(config);
AbstractBasePtr clone = attr.second->Broaden();
AbstractAttribute elem(attr.first, clone);
attributes_clone.push_back(elem);
}
@ -1022,15 +1023,6 @@ std::string AbstractNone::ToString() const {
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
AbstractBasePtr AbstractRefKey::Broaden(uint8_t config) const {
auto refkey = std::make_shared<AbstractRefKey>();
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
if (not_broaden == 0) {
refkey->set_value(kAnyValue);
}
return refkey;
}
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
ValuePtr value_self = GetValueTrack();
ValuePtr value_other = other.GetValueTrack();
@ -1141,9 +1133,9 @@ AbstractBasePtr AbstractKeywordArg::Clone() const {
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
}
AbstractBasePtr AbstractKeywordArg::Broaden(uint8_t config) const {
AbstractBasePtr AbstractKeywordArg::Broaden() const {
MS_EXCEPTION_IF_NULL(arg_value_);
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden(config));
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
}
std::size_t AbstractKeywordArg::hash() const {
@ -1259,7 +1251,7 @@ AbstractBasePtr AbstractRowTensor::Clone() const {
return clone;
}
AbstractBasePtr AbstractRowTensor::Broaden(uint8_t config) const {
AbstractBasePtr AbstractRowTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
auto shp = shape();
@ -1351,7 +1343,7 @@ AbstractBasePtr AbstractSparseTensor::Clone() const {
return clone;
}
AbstractBasePtr AbstractSparseTensor::Broaden(uint8_t config) const {
AbstractBasePtr AbstractSparseTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape();

View File

@ -59,7 +59,7 @@ class AbstractBase : public Base {
virtual bool operator==(const AbstractBase &other) const;
void set_value(const ValuePtr &value) { value_ = value; }
void set_type(const TypePtr &type) { type_ = type; }
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
virtual void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
const std::string &value_desc() const { return value_desc_; }
ValuePtr GetValueTrack() const { return value_; }
@ -79,16 +79,7 @@ class AbstractBase : public Base {
}
inline static TraceNodeProvider trace_node_provider_ = nullptr;
// mask for Broaden config
inline static const uint8_t kBroadenTensorOnly = 1;
inline static const uint8_t kBroadenParameterOnly = 2;
// Scalar as Parameter, should boarden
inline static const uint8_t kBroadenScalarParameterOnly = 4;
// Each bit for on config.
// 00000001 -> 1: only boarden tensor
// 00000010 -> 2: only boarden parameter
// 00000100 -> 4: only boarden scalar parameter
virtual AbstractBasePtr Broaden(uint8_t config = 0) const;
virtual AbstractBasePtr Broaden() const;
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); }
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) {
@ -128,7 +119,7 @@ class AbstractScalar : public AbstractBase {
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
}
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Join(const AbstractBasePtr &other) override;
};
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>;
@ -148,7 +139,7 @@ class AbstractType : public AbstractBase {
TypePtr BuildType() const override { return std::make_shared<TypeType>(); }
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
AbstractBasePtr Broaden() const override { return Clone(); }
};
using AbstractTypePtr = std::shared_ptr<AbstractType>;
@ -163,7 +154,7 @@ class AbstractError : public AbstractBase {
MS_DECLARE_PARENT(AbstractError, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<Problem>(); }
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
AbstractBasePtr Broaden() const override { return Clone(); }
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_);
@ -200,7 +191,7 @@ class AbstractFunction : public AbstractBase {
TypePtr BuildType() const override { return std::make_shared<Function>(); }
AbstractBasePtr Clone() const override { return Copy(); }
// For Function, no need to broaden.
AbstractBasePtr Broaden(uint8_t config = 0) const override {
AbstractBasePtr Broaden() const override {
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
}
virtual AbstractFunctionPtr Copy() const = 0;
@ -229,7 +220,7 @@ class AbstractKeywordArg : public AbstractBase {
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
std::size_t hash() const override;
bool operator==(const AbstractKeywordArg &other) const;
@ -261,28 +252,37 @@ class AbstractUndetermined : public AbstractBase {
if (element->isa<AbstractUndetermined>()) {
MS_LOG(EXCEPTION) << "element type error";
}
set_shape(shape);
MS_EXCEPTION_IF_NULL(shape);
if (shape->isa<NoShape>()) {
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
}
AbstractBase::set_shape(shape);
}
AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape)
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
if (element_type == nullptr) {
MS_LOG(EXCEPTION) << "element_type is nullptr";
}
set_shape(std::make_shared<Shape>(shape));
AbstractBase::set_shape(std::make_shared<Shape>(shape));
}
explicit AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
if (element_type == nullptr) {
MS_LOG(EXCEPTION) << "element_type is nullptr";
}
set_shape(shape);
MS_EXCEPTION_IF_NULL(shape);
if (shape->isa<NoShape>()) {
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
}
AbstractBase::set_shape(shape);
}
~AbstractUndetermined() override = default;
MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }
const AbstractBasePtr element() const { return element_; }
AbstractBasePtr element() const { return element_; }
ShapePtr shape() const;
void set_shape(const BaseShapePtr &shape) override;
protected:
AbstractBasePtr element_;
@ -310,7 +310,7 @@ class AbstractTensor : public AbstractUndetermined {
TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractTensor &other) const;
@ -345,7 +345,7 @@ class AbstractSequeue : public AbstractBase {
TypePtrList ElementsType() const;
BaseShapePtrList ElementsShape() const;
AbstractBasePtrList ElementsClone() const;
AbstractBasePtrList ElementsBroaden(uint8_t config = 0) const;
AbstractBasePtrList ElementsBroaden() const;
template <typename T>
ValuePtr ElementsBuildValue() const;
@ -379,9 +379,7 @@ class AbstractTuple : public AbstractSequeue {
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractTuple>(ElementsBroaden(config));
}
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); }
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); }
@ -408,9 +406,7 @@ class AbstractList : public AbstractSequeue {
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractList>(ElementsBroaden(config));
}
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); }
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); }
@ -442,7 +438,7 @@ class AbstractClass : public AbstractBase {
AbstractBasePtr GetAttribute(const std::string &name);
ValuePtr GetMethod(const std::string &name);
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override;
Named tag() const { return tag_; }
std::size_t hash() const override;
@ -467,7 +463,7 @@ class AbstractDictionary : public AbstractBase {
bool operator==(const AbstractDictionary &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override;
std::size_t hash() const override;
std::size_t size() const { return key_values_.size(); }
@ -491,7 +487,7 @@ class AbstractSlice : public AbstractBase {
bool operator==(const AbstractSlice &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override;
std::size_t hash() const override;
AbstractBasePtr start() const { return start_; }
@ -517,9 +513,7 @@ class AbstractJTagged : public AbstractBase {
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractJTagged>(element_->Broaden(config));
}
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractJTagged &other) const;
@ -616,7 +610,6 @@ class AbstractRefKey : public AbstractBase {
}
RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override;
private:
@ -628,7 +621,6 @@ using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
class AbstractRef : public AbstractTensor {
public:
AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
~AbstractRef() override = default;
MS_DECLARE_PARENT(AbstractRef, AbstractTensor)
@ -647,13 +639,13 @@ class AbstractRef : public AbstractTensor {
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
inline AbstractBasePtr ref_key() const { return ref_key_; }
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
AbstractBasePtr Broaden() const override {
// always broaden for ref
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
if (abs_tensor == nullptr) {
return nullptr;
}
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor);
return std::make_shared<AbstractRef>(ref_key_->Broaden(), abs_tensor);
}
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override {
@ -696,7 +688,7 @@ class AbstractRowTensor : public AbstractUndetermined {
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
@ -725,7 +717,7 @@ class AbstractSparseTensor : public AbstractUndetermined {
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
@ -743,7 +735,7 @@ class AbstractMonad : public AbstractBase {
std::size_t hash() const override { return hash_combine({tid()}); }
TypePtr BuildType() const override { return GetTypeTrack(); }
AbstractBasePtr Broaden(uint8_t config) const override { return AbstractBase::Broaden(config); }
AbstractBasePtr Broaden() const override { return AbstractBase::Broaden(); }
AbstractBasePtr Join(const AbstractBasePtr &other) override = 0;
std::string ToString() const override {
std::ostringstream buffer;