forked from mindspore-Ecosystem/mindspore
!26437 Optimize constant folding for pynative
Merge pull request !26437 from JoyLvliang/optimize_constant_folding_for_pynative
This commit is contained in:
commit
8d05a3870b
|
@ -304,9 +304,29 @@ py::object BuildValue(const ValuePtr &value_ptr) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
py::object AbstractTupleValueToPython(const AbstractTuplePtr &tuple_abs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_abs);
|
||||||
|
auto value = tuple_abs->BuildValue();
|
||||||
|
if (value->isa<AnyValue>()) {
|
||||||
|
return py::none();
|
||||||
|
}
|
||||||
|
const auto &elements = tuple_abs->elements();
|
||||||
|
size_t len = elements.size();
|
||||||
|
py::tuple value_tuple(len);
|
||||||
|
for (size_t i = 0; i < len; ++i) {
|
||||||
|
value_tuple[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
|
||||||
|
}
|
||||||
|
return std::move(value_tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
|
||||||
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
||||||
MS_EXCEPTION_IF_NULL(arg_tuple);
|
MS_EXCEPTION_IF_NULL(arg_tuple);
|
||||||
|
auto dic = py::dict();
|
||||||
|
if (only_convert_value) {
|
||||||
|
dic[ATTR_VALUE] = AbstractTupleValueToPython(arg_tuple);
|
||||||
|
return dic;
|
||||||
|
}
|
||||||
size_t len = arg_tuple->size();
|
size_t len = arg_tuple->size();
|
||||||
py::tuple shape_tuple(len);
|
py::tuple shape_tuple(len);
|
||||||
py::tuple dtype_tuple(len);
|
py::tuple dtype_tuple(len);
|
||||||
|
@ -319,8 +339,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
||||||
bool dyn_value = false;
|
bool dyn_value = false;
|
||||||
|
|
||||||
for (size_t i = 0; i < len; i++) {
|
for (size_t i = 0; i < len; i++) {
|
||||||
auto arg = arg_tuple->elements()[i];
|
py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
|
||||||
py::dict out = ConvertAbstractToPython(arg);
|
|
||||||
shape_tuple[i] = out[ATTR_SHAPE];
|
shape_tuple[i] = out[ATTR_SHAPE];
|
||||||
dtype_tuple[i] = out[ATTR_DTYPE];
|
dtype_tuple[i] = out[ATTR_DTYPE];
|
||||||
value_tuple[i] = out[ATTR_VALUE];
|
value_tuple[i] = out[ATTR_VALUE];
|
||||||
|
@ -339,7 +358,6 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
||||||
dyn_shape = true;
|
dyn_shape = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto dic = py::dict();
|
|
||||||
dic[ATTR_SHAPE] = shape_tuple;
|
dic[ATTR_SHAPE] = shape_tuple;
|
||||||
dic[ATTR_DTYPE] = dtype_tuple;
|
dic[ATTR_DTYPE] = dtype_tuple;
|
||||||
MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
|
MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
|
||||||
|
@ -361,9 +379,29 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
||||||
return dic;
|
return dic;
|
||||||
}
|
}
|
||||||
|
|
||||||
py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
py::object AbstractListValueToPython(const AbstractListPtr &list_abs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(list_abs);
|
||||||
|
auto value = list_abs->BuildValue();
|
||||||
|
if (value->isa<AnyValue>()) {
|
||||||
|
return py::none();
|
||||||
|
}
|
||||||
|
const auto &elements = list_abs->elements();
|
||||||
|
size_t len = elements.size();
|
||||||
|
py::list value_list(len);
|
||||||
|
for (size_t i = 0; i < len; ++i) {
|
||||||
|
value_list[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
|
||||||
|
}
|
||||||
|
return std::move(value_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
|
||||||
auto arg_list = dyn_cast<AbstractList>(abs_base);
|
auto arg_list = dyn_cast<AbstractList>(abs_base);
|
||||||
MS_EXCEPTION_IF_NULL(arg_list);
|
MS_EXCEPTION_IF_NULL(arg_list);
|
||||||
|
auto dic = py::dict();
|
||||||
|
if (only_convert_value) {
|
||||||
|
dic[ATTR_VALUE] = AbstractListValueToPython(arg_list);
|
||||||
|
return dic;
|
||||||
|
}
|
||||||
size_t len = arg_list->size();
|
size_t len = arg_list->size();
|
||||||
py::list shape_list(len);
|
py::list shape_list(len);
|
||||||
py::list dtype_list(len);
|
py::list dtype_list(len);
|
||||||
|
@ -385,7 +423,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
||||||
dyn_shape = true;
|
dyn_shape = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto dic = py::dict();
|
|
||||||
dic[ATTR_SHAPE] = shape_list;
|
dic[ATTR_SHAPE] = shape_list;
|
||||||
dic[ATTR_DTYPE] = dtype_list;
|
dic[ATTR_DTYPE] = dtype_list;
|
||||||
MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
|
MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
|
||||||
|
@ -403,10 +441,14 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
||||||
return dic;
|
return dic;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
|
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
|
||||||
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
||||||
MS_EXCEPTION_IF_NULL(dic);
|
MS_EXCEPTION_IF_NULL(dic);
|
||||||
MS_EXCEPTION_IF_NULL(arg_tensor);
|
MS_EXCEPTION_IF_NULL(arg_tensor);
|
||||||
|
if (only_convert_value) {
|
||||||
|
(*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
|
||||||
|
return;
|
||||||
|
}
|
||||||
MS_EXCEPTION_IF_NULL(arg_tensor->shape());
|
MS_EXCEPTION_IF_NULL(arg_tensor->shape());
|
||||||
(*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
|
(*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
|
||||||
const auto &min_shape = arg_tensor->shape()->min_shape();
|
const auto &min_shape = arg_tensor->shape()->min_shape();
|
||||||
|
@ -477,11 +519,26 @@ TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_lis
|
||||||
}
|
}
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
|
||||||
MS_EXCEPTION_IF_NULL(abs_base);
|
MS_EXCEPTION_IF_NULL(abs_base);
|
||||||
auto dic = py::dict();
|
auto dic = py::dict();
|
||||||
if (abs_base->isa<AbstractTensor>()) {
|
if (abs_base->isa<AbstractTensor>()) {
|
||||||
ConvertAbstractTensorToPython(abs_base, &dic);
|
ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
|
||||||
|
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
|
||||||
|
ShapeVector shape;
|
||||||
|
dic[ATTR_SHAPE] = shape;
|
||||||
|
dic[ATTR_DTYPE] = abs_base->BuildType();
|
||||||
|
dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
|
||||||
|
} else if (abs_base->isa<AbstractTuple>()) {
|
||||||
|
return AbstractTupleToPython(abs_base, only_convert_value);
|
||||||
|
} else if (abs_base->isa<AbstractList>()) {
|
||||||
|
return AbstractListToPython(abs_base, only_convert_value);
|
||||||
|
} else if (abs_base->isa<AbstractSlice>()) {
|
||||||
|
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
|
||||||
|
ShapeVector shape;
|
||||||
|
dic[ATTR_SHAPE] = shape;
|
||||||
|
dic[ATTR_DTYPE] = arg_slice->BuildType();
|
||||||
|
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractRowTensor>()) {
|
} else if (abs_base->isa<AbstractRowTensor>()) {
|
||||||
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
|
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
|
||||||
dic[ATTR_SHAPE] = arg->shape()->shape();
|
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||||
|
@ -497,25 +554,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||||
dic[ATTR_SHAPE] = arg->shape()->shape();
|
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||||
dic[ATTR_DTYPE] = arg->BuildType();
|
dic[ATTR_DTYPE] = arg->BuildType();
|
||||||
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
|
|
||||||
ShapeVector shape;
|
|
||||||
dic[ATTR_SHAPE] = shape;
|
|
||||||
dic[ATTR_DTYPE] = abs_base->BuildType();
|
|
||||||
dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
|
|
||||||
} else if (abs_base->isa<AbstractSlice>()) {
|
|
||||||
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
|
|
||||||
ShapeVector shape;
|
|
||||||
dic[ATTR_SHAPE] = shape;
|
|
||||||
dic[ATTR_DTYPE] = arg_slice->BuildType();
|
|
||||||
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
|
|
||||||
} else if (abs_base->isa<AbstractEllipsis>()) {
|
} else if (abs_base->isa<AbstractEllipsis>()) {
|
||||||
dic[ATTR_SHAPE] = py::none();
|
dic[ATTR_SHAPE] = py::none();
|
||||||
dic[ATTR_DTYPE] = py::ellipsis();
|
dic[ATTR_DTYPE] = py::ellipsis();
|
||||||
dic[ATTR_VALUE] = py::ellipsis();
|
dic[ATTR_VALUE] = py::ellipsis();
|
||||||
} else if (abs_base->isa<AbstractTuple>()) {
|
|
||||||
return AbstractTupleToPython(abs_base);
|
|
||||||
} else if (abs_base->isa<AbstractList>()) {
|
|
||||||
return AbstractListToPython(abs_base);
|
|
||||||
} else if (abs_base->isa<AbstractNone>()) {
|
} else if (abs_base->isa<AbstractNone>()) {
|
||||||
dic[ATTR_SHAPE] = py::none();
|
dic[ATTR_SHAPE] = py::none();
|
||||||
dic[ATTR_DTYPE] = py::none();
|
dic[ATTR_DTYPE] = py::none();
|
||||||
|
|
|
@ -174,7 +174,7 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
|
||||||
|
|
||||||
void ClearPrimEvaluatorMap();
|
void ClearPrimEvaluatorMap();
|
||||||
|
|
||||||
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
|
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value = false);
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -1242,19 +1242,18 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
|
||||||
const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
|
const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
|
||||||
bool prim_cache_hit, py::object *ret) {
|
bool prim_cache_hit, py::object *ret) {
|
||||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||||
auto prim = op_exec_info->py_primitive;
|
const auto &prim = op_exec_info->py_primitive;
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
// Infer output value by constant folding
|
// Infer output value by constant folding
|
||||||
MS_EXCEPTION_IF_NULL(ret);
|
MS_EXCEPTION_IF_NULL(ret);
|
||||||
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract, true);
|
||||||
if (!output["value"].is_none()) {
|
if (!output[ATTR_VALUE].is_none()) {
|
||||||
*ret = output["value"];
|
*ret = output[ATTR_VALUE];
|
||||||
grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
|
grad()->RecordGradOpInfo(op_exec_info);
|
||||||
return;
|
return;
|
||||||
}
|
} else if (prim->is_const_prim()) {
|
||||||
if (prim->is_const_prim()) {
|
|
||||||
*ret = py::cast("");
|
*ret = py::cast("");
|
||||||
grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
|
grad()->RecordGradOpInfo(op_exec_info);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1296,7 +1295,7 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
|
||||||
node_abs_map_.clear();
|
node_abs_map_.clear();
|
||||||
}
|
}
|
||||||
// Record op info for judge whether the construct of cell has been changed
|
// Record op info for judge whether the construct of cell has been changed
|
||||||
grad()->RecordGradOpInfo(op_exec_info, out_real_value);
|
grad()->RecordGradOpInfo(op_exec_info);
|
||||||
grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
|
grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1740,13 +1739,12 @@ void GradExecutor::EnableOpGraphCache(bool is_enable) {
|
||||||
inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
|
inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
|
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
|
||||||
if (!grad_flag_) {
|
if (!grad_flag_) {
|
||||||
MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
|
MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||||
MS_EXCEPTION_IF_NULL(op_out);
|
|
||||||
std::string input_args_info;
|
std::string input_args_info;
|
||||||
// Record input args info (weight or data)
|
// Record input args info (weight or data)
|
||||||
for (const auto mask : op_exec_info->inputs_mask) {
|
for (const auto mask : op_exec_info->inputs_mask) {
|
||||||
|
@ -1761,11 +1759,12 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const Val
|
||||||
const auto &curr_op_num = top_cell()->op_num();
|
const auto &curr_op_num = top_cell()->op_num();
|
||||||
op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
|
op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
|
||||||
// The out shape is added to determine those ops that change the shape
|
// The out shape is added to determine those ops that change the shape
|
||||||
auto out_abs = op_out->ToAbstract();
|
const auto &out_abs = op_exec_info->abstract;
|
||||||
if (out_abs != nullptr) {
|
if (out_abs != nullptr) {
|
||||||
auto out_shape = out_abs->BuildShape()->ToString();
|
auto shape = out_abs->BuildShape();
|
||||||
if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
op_exec_info->op_info += "-" + out_shape;
|
if (!shape->isa<abstract::NoShape>() && !shape->IsDimZero()) {
|
||||||
|
op_exec_info->op_info += "-" + shape->ToString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
top_cell()->all_op_info() += "-" + op_exec_info->op_info;
|
top_cell()->all_op_info() += "-" + op_exec_info->op_info;
|
||||||
|
@ -3230,7 +3229,8 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::objec
|
||||||
// Identity op info for current running ms_func graph.
|
// Identity op info for current running ms_func graph.
|
||||||
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
|
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
|
||||||
op_exec_info->op_name = phase;
|
op_exec_info->op_name = phase;
|
||||||
RecordGradOpInfo(op_exec_info, actual_out_v);
|
op_exec_info->abstract = actual_out_v->ToAbstract();
|
||||||
|
RecordGradOpInfo(op_exec_info);
|
||||||
MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;
|
MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;
|
||||||
|
|
||||||
// Step 1: Update actual output tensors used in grad graph.
|
// Step 1: Update actual output tensors used in grad graph.
|
||||||
|
|
|
@ -201,7 +201,7 @@ class GradExecutor {
|
||||||
std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
|
std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
|
||||||
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
||||||
std::string GetCellId(const py::object &obj, const py::args &args);
|
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||||
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
|
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
|
||||||
bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
|
bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
|
||||||
// Construct grad graph for ms_function
|
// Construct grad graph for ms_function
|
||||||
bool eliminate_forward() const { return eliminate_forward_; }
|
bool eliminate_forward() const { return eliminate_forward_; }
|
||||||
|
|
|
@ -69,6 +69,11 @@ class MS_CORE_API BaseShape : public Base {
|
||||||
/// \return True if the object's dimensions are dynamic, otherwise false.
|
/// \return True if the object's dimensions are dynamic, otherwise false.
|
||||||
virtual bool IsDynamic() const = 0;
|
virtual bool IsDynamic() const = 0;
|
||||||
|
|
||||||
|
/// \brief Whether the object's dimension is zero.
|
||||||
|
///
|
||||||
|
/// \return True if the object's dimension is zero, otherwise false.
|
||||||
|
virtual bool IsDimZero() const = 0;
|
||||||
|
|
||||||
/// \brief Whether the object's dimensions are unknown.
|
/// \brief Whether the object's dimensions are unknown.
|
||||||
///
|
///
|
||||||
/// \return True if the object's dimensions are unknown, otherwise false.
|
/// \return True if the object's dimensions are unknown, otherwise false.
|
||||||
|
@ -97,6 +102,8 @@ class MS_CORE_API NoShape final : public BaseShape {
|
||||||
|
|
||||||
bool IsDynamic() const override { return false; }
|
bool IsDynamic() const override { return false; }
|
||||||
|
|
||||||
|
bool IsDimZero() const override { return true; };
|
||||||
|
|
||||||
bool IsDimUnknown() const override { return false; }
|
bool IsDimUnknown() const override { return false; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -172,6 +179,8 @@ class MS_CORE_API Shape final : public BaseShape {
|
||||||
return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
|
return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsDimZero() const override { return shape_.empty(); };
|
||||||
|
|
||||||
bool IsDimUnknown() const override {
|
bool IsDimUnknown() const override {
|
||||||
return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < -1; });
|
return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < -1; });
|
||||||
}
|
}
|
||||||
|
@ -236,6 +245,10 @@ class MS_CORE_API SequenceShape : public BaseShape {
|
||||||
return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
|
return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsDimZero() const override {
|
||||||
|
return std::all_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimZero(); });
|
||||||
|
};
|
||||||
|
|
||||||
bool IsDimUnknown() const override {
|
bool IsDimUnknown() const override {
|
||||||
return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); });
|
return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); });
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue