forked from mindspore-Ecosystem/mindspore
[ge_adapter]: implement control flow (while, while by while, and case in while)
This commit is contained in:
parent ca68555cc0
commit 3c81a6e1eb
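The kind of graph this change targets, shown as a minimal MindSpore sketch adapted from the tests added in this commit (the class name and the concrete values are illustrative only): each while loop is lowered to a GE While operator with cond/body subgraphs, and indexing a CellList with a tensor is lowered to a Case operator with one branch subgraph per cell.

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context


class SimpleCell(nn.Cell):
    def __init__(self, i):
        super().__init__()
        self.i = i

    def construct(self, x):
        return self.i * x


class WhileByWhileWithCase(nn.Cell):
    """While loop followed by a while loop that contains a case (switch layer)."""
    def __init__(self):
        super().__init__()
        self.cell_list = nn.CellList([SimpleCell(4), SimpleCell(5), SimpleCell(6)])

    def construct(self, t, x):
        out = t
        while x < 3:  # first while
            out += 4
            x += 1
        x = 0
        while x < 3:  # second while, with a case inside
            out = out + self.cell_list[x](t)
            x += 1
        return out


context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = WhileByWhileWithCase()
out = net(Tensor(10, mstype.int32), Tensor(0, mstype.int32))  # 172, per the test added below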
@@ -33,11 +33,13 @@
#include "include/common/utils/config_manager.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "ops/core_ops.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "include/transform/graph_ir/util.h"
#include "ir/tensor.h"
#include "include/transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/op_adapter.h"
#include "graph/operator_reg.h"
#include "external/ge/ge_api.h"
#include "graph/tensor.h"
@@ -51,6 +53,8 @@ using TensorOrderMap = std::map<std::string, std::shared_ptr<tensor::Tensor>>;
using HcomBroadcast = ge::op::HcomBroadcast;
using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;

using ParamIndexMap = std::map<std::size_t, std::size_t>;
enum class GraphType { kNormal, kCond, kBody, kAfter, kBranch };
class COMMON_EXPORT DfGraphConvertor {
public:
explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
@@ -101,6 +105,7 @@ class COMMON_EXPORT DfGraphConvertor {
fout.close();
#endif
}

void DrawSaveCheckpointGraph(const std::string &name) {
std::ofstream fout(name);
if (!fout.is_open()) {
@@ -112,11 +117,13 @@ class COMMON_EXPORT DfGraphConvertor {
}

DfGraphConvertor &ConvertAllNode();
bool SetGraphInputs(const std::vector<Operator> &inputs);
DfGraphConvertor &BuildGraph();
DfGraphConvertor &InitParam(const TensorOrderMap &tensors);
DfGraphConvertor &GenerateCheckpointGraph();
DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors);
void InitParamWithData(const TensorOrderMap &tensors);
OutHandler GetNormalOpInput(const AnfNodePtr &pred);
void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node);
void SetupBroadcast(const std::shared_ptr<HcomBroadcast> &broadcast, const std::vector<GeTensorDesc> &broadcast_desc,
const DfGraphPtr &broadcast_graph, std::vector<ge::Operator> broadcast_input);
@@ -150,6 +157,7 @@ class COMMON_EXPORT DfGraphConvertor {
AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
AnfNodePtr TraceDepend(const CNodePtr &node);
OutHandler TraceRealOp(AnfNodePtr node);
OutHandler GetHandler(const AnfNodePtr &node);
OutHandler GetHandler(const AnfNodePtr &node, const std::stack<uint64_t> &index_stack, AnfNode *const draw_index);
OperatorPtr Convert(AnfNodePtr node);
OperatorPtr ConvertCNode(CNodePtr node);
@@ -194,8 +202,27 @@ class COMMON_EXPORT DfGraphConvertor {
void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src,
int index);
void UpdateTupleOutCache(void);
AnfNodePtr TransformConstOp(const CNodePtr &node, AnfNodePtr pred);
AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input);

void ConvertWhileNode(const CNodePtr &node);
void CacheWhileGraph(const CNodePtr &cnode);
void ConvertWhileBody(const AnfNodePtr &node);
std::shared_ptr<std::vector<Operator>> GetWhileSubGraphInput();
void BuildWhileSubGraph();
void ConvertWhileCond(const AnfNodePtr &node);
void ConvertWhileAfter(const AnfNodePtr &node);
void BuildWhileAfterSubGraph();
void GetCallNodeInputs(const CNodePtr &node);
std::vector<Operator> GetWhileBodyOutputs();
bool IsSubGraph() const { return graph_type_ == GraphType::kCond || graph_type_ == GraphType::kBody; }
bool IsAfterGraph() const { return graph_type_ == GraphType::kAfter; }
bool IsNormalGraph() const { return graph_type_ == GraphType::kNormal; }
bool IsBranchGraph() const { return graph_type_ == GraphType::kBranch; }
void SetParamIndexMap(const std::vector<AnfNodePtr> &graphs);
void SetWhileOutputHandle(const OperatorPtr &prev_while_op);
void GetWhileUsedInputIndex(const std::vector<AnfNodePtr> &graphs);

std::shared_ptr<AnfGraph> anf_graph_{nullptr};
std::shared_ptr<DfGraph> df_graph_{nullptr};
std::shared_ptr<DfGraph> init_graph_{nullptr};
@@ -214,6 +241,7 @@ class COMMON_EXPORT DfGraphConvertor {
mindspore::HashMap<std::string, AnfNodePtr> params_;
mindspore::HashMap<std::string, OperatorPtr> vars_;
std::vector<std::pair<ge::Operator, std::string>> graph_outputs_;
std::vector<AnfNodePtr> graph_anf_outputs_;
std::vector<OperatorPtr> graph_const_inputs_;
std::vector<OperatorPtr> init_ops_;
std::vector<OperatorPtr> broadcast_ops_;
@@ -223,6 +251,36 @@ class COMMON_EXPORT DfGraphConvertor {
bool training_ = false;
bool distribute_ = false;
bool use_inputs_ = false;

AnfNodePtr while_cond_node_ = nullptr;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_;

CNodePtr cur_while_node_ = nullptr;
size_t cur_while_node_out_size_ = 0;
mindspore::HashMap<size_t, OutHandler> while_const_input_index_;
mindspore::HashMap<size_t, OutHandler> prev_while_const_input_index_;
mindspore::HashMap<size_t, size_t> prev_cond_to_while_out_index_;
mindspore::HashMap<OperatorPtr, std::shared_ptr<tensor::Tensor>> const_op_to_value_;
AnfNodePtr prev_while_node_ = nullptr;
size_t prev_while_node_out_size_ = 0;

mindspore::HashMap<AnfNodePtr, std::vector<AnfNodePtr>> while_graph_cache_;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> call_input_handle_cache_;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> while_output_handle_cache_;
AnfNodePtr call_node_in_while_body_ = nullptr;
GraphType graph_type_ = GraphType::kNormal;

ParamIndexMap body_cond_map_;
ParamIndexMap after_cond_map_;
ParamIndexMap prev_after_cond_map_;
mindspore::HashMap<size_t, OperatorPtr> subgraph_input_cache_;

std::set<size_t> while_used_input_index_;
std::set<size_t> prev_while_used_input_index_;

mindspore::HashMap<size_t, OutHandler> bypass_node_prev_handle_cache_;
mindspore::HashMap<size_t, OutHandler> bypass_node_handle_cache_;
size_t case_call_input_size_;
};
} // namespace transform
} // namespace mindspore

@@ -349,6 +349,7 @@ constexpr const char kNameTensorArray[] = "TensorArray";
constexpr const char kNameTensorArrayWrite[] = "TensorArrayWrite";
constexpr const char kNameTensorArrayGather[] = "TensorArrayGather";
constexpr const char kNameTensorMove[] = "TensorMove";
constexpr const char kNameWhile[] = "While";

class OpAdapterDesc;

@@ -234,6 +234,14 @@ class COMMON_EXPORT TransformUtil {
}
return dest;
}

/*
 * Parameters:
 *     anf_name: [string] the anf node name
 * Return:
 *     [string] operator name
 * */
static std::string NormOpName(const std::string &anf_name);
};
} // namespace transform
} // namespace mindspore

File diff suppressed because it is too large

@@ -120,6 +120,17 @@ Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, int index,
return NOT_FOUND;
}

Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs) {
MS_EXCEPTION_IF_NULL(op);
if (subgraph_map_.size() != subgraphs->size()) {
return INVALID_ARGUMENT;
}
for (size_t i = 0; i < subgraphs->size(); i++) {
subgraph_map_.at(i).set_subgraph(op, std::make_shared<DfGraph>((*subgraphs)[i]));
}
return SUCCESS;
}

Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
MS_EXCEPTION_IF_NULL(op);
MS_EXCEPTION_IF_NULL(input);

@@ -349,7 +360,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
}

if (output_size != tuple_shp->shape().size()) {
-MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
+MS_LOG(ERROR) << "output_map is not equal to tuple_shape size";
return FAILED;
}

@@ -33,6 +33,7 @@ class OpAdapterImpl {
const mindspore::HashMap<int, DynInputDesc> &dyn_input_map,
const mindspore::HashMap<int, OutputDesc> &output_map,
const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map,
const mindspore::HashMap<int, SubGraphDesc> &subgraph_map,
const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map,
const mindspore::HashMap<std::string, AttrDesc> &attr_map,
const mindspore::HashMap<std::string, int> &enum_map,

@@ -45,6 +46,7 @@ class OpAdapterImpl {
dyn_input_map_(dyn_input_map),
output_map_(output_map),
dyn_output_map_(dyn_output_map),
subgraph_map_(subgraph_map),
dyn_subgraph_map_(dyn_subgraph_map),
attr_map_(attr_map),
enum_map_(enum_map),

@@ -65,6 +67,7 @@ class OpAdapterImpl {
Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
OperatorPtr GenerateCustomOp(const AnfNodePtr anf);
Status SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs);
Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches);
Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input);
Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input);

@@ -100,6 +103,7 @@ class OpAdapterImpl {
const mindspore::HashMap<int, DynInputDesc> &dyn_input_map_;
const mindspore::HashMap<int, OutputDesc> &output_map_;
const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map_;
const mindspore::HashMap<int, SubGraphDesc> &subgraph_map_;
const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map_;
const mindspore::HashMap<std::string, AttrDesc> &attr_map_;
const mindspore::HashMap<std::string, int> &enum_map_;
@@ -116,14 +120,14 @@ class OpAdapter : public BaseOpAdapter {
public:
using OpType = T;
OpAdapter()
-: impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_,
+: impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_,
dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_,
&cus_output_map_, &extra_attr_, &name_counts_, this)) {
MS_EXCEPTION_IF_NULL(impl_);
}
explicit OpAdapter(const ExtraAttr &extra_attr)
: extra_attr_(extra_attr),
-impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_,
+impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_,
dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_,
&cus_output_map_, &extra_attr_, &name_counts_, this)) {
MS_EXCEPTION_IF_NULL(impl_);

@@ -148,8 +152,17 @@ class OpAdapter : public BaseOpAdapter {
// There are duplicate names in ANF graph, do not assign ANF node name to GE
// GE will generate unique name automatically
if (anf != nullptr && anf->fullname_with_scope() != "") {
-MS_LOG(DEBUG) << anf->fullname_with_scope();
-op = std::make_shared<OpType>(anf->fullname_with_scope());
+auto name = anf->fullname_with_scope();
+MS_LOG(DEBUG) << name;
+string user_data_key = "subgraph_node";
+if (anf->has_user_data(user_data_key) && *(anf->user_data<bool>(user_data_key))) {
+auto norm_name = TransformUtil::NormOpName(name);
+if (norm_name != name) {
+MS_LOG(DEBUG) << "normalize anf name : " << name << " to operator name : " << norm_name;
+}
+name = norm_name;
+}
+op = std::make_shared<OpType>(name);
} else {
MS_LOG(DEBUG) << "no fullname_with_scope";
op = std::make_shared<OpType>();
@@ -169,6 +182,28 @@ class OpAdapter : public BaseOpAdapter {
return op;
}

OperatorPtr GenerateDynamicOutputOp(const AnfNodePtr &anf) {
OperatorPtr op = nullptr;
// There are duplicate names in ANF graph, do not assign ANF node name to GE
// GE will generate unique name automatically
if (anf != nullptr && anf->fullname_with_scope() != "") {
MS_LOG(DEBUG) << anf->fullname_with_scope();
op = std::make_shared<OpType>(anf->fullname_with_scope());
} else {
MS_LOG(DEBUG) << "no fullname_with_scope";
op = std::make_shared<OpType>();
}
return op;
}

void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) override {
// set dynamic output num if op use DYNAMIC_OUTPUT
if ((op != nullptr) && (!dyn_output_map_.empty())) {
MS_LOG(DEBUG) << "create_dyn_output for node:" << op->GetName() << ", num:" << dyn_output_size;
dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(dyn_output_size));
}
}

OperatorPtr generate(const AnfNodePtr &anf) override {
OperatorPtr op = nullptr;
if (IsCustomCNode(anf)) {

@@ -184,12 +219,30 @@ class OpAdapter : public BaseOpAdapter {

OperatorPtr generate(const std::string &op_name) override { return std::make_shared<OpType>(op_name); }

OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) override {
OperatorPtr op = nullptr;
op = GenerateDynamicOutputOp(anf);
if (op == nullptr) {
MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope();
}
return op;
}

const mindspore::HashMap<int, InputDesc> &getInputMap() override { return input_map_; }
const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() override { return subgraph_map_; }
const mindspore::HashMap<int, OutputDesc> &getOutputMap() override { return output_map_; }
const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }

Status SetOpSubgraphFunc(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) {
return impl_->SetOpSubgraphFunc(op, subgraphs);
}

int setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) override {
return static_cast<int>(SetOpSubgraphFunc(op, subgraphs));
}

Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) {
return impl_->SetOpSubgraphFunc(op, index, branches);
}
@@ -428,6 +481,7 @@ class OpAdapter : public BaseOpAdapter {
static const mindspore::HashMap<int, DynInputDesc> dyn_input_map_;
static const mindspore::HashMap<int, OutputDesc> output_map_;
static const mindspore::HashMap<int, DynOutputDesc> dyn_output_map_;
static const mindspore::HashMap<int, SubGraphDesc> subgraph_map_;
static const mindspore::HashMap<int, DynSubGraphDesc> dyn_subgraph_map_;
static const mindspore::HashMap<std::string, AttrDesc> attr_map_;
static const mindspore::HashMap<std::string, int> enum_map_;

@@ -449,6 +503,8 @@ const mindspore::HashMap<int, OutputDesc> OpAdapter<T>::output_map_;
template <typename T>
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
template <typename T>
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_;
template <typename T>
const mindspore::HashMap<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
template <typename T>
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;

@@ -464,5 +520,4 @@ mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> OpAdapter<
// specialization for method
} // namespace transform
} // namespace mindspore

#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_

@@ -63,7 +63,10 @@ using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr
using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
using UpdateDynOutputDescFunc = std::function<void(OperatorPtr, unsigned int, GeTensorDesc)>;
using SubGraphFunc = std::function<void(OperatorPtr, DfGraphPtr)>;
using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>;

using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>;

struct AttrDesc {

@@ -85,6 +88,11 @@ struct DynInputDesc {
DynInputHandleFunc set_handle;
};

struct SubGraphDesc {
std::string name;
SubGraphFunc set_subgraph;
};

struct DynSubGraphDesc {
std::string name;
CreateDynSubGraphFunc create_dyn_subgraph;

@@ -99,6 +107,7 @@ struct OutputDesc {
struct DynOutputDesc {
std::string name;
CreateDynOutputOpFunc create_dyn_output;
UpdateDynOutputDescFunc update_dyn_output_desc;
};

class BaseOpAdapter {

@@ -106,6 +115,9 @@ class BaseOpAdapter {
virtual ~BaseOpAdapter() {}
virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
virtual OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) { return nullptr; }
virtual void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) { return; }
virtual int setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) = 0;
virtual int setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;

@@ -130,6 +142,7 @@ class BaseOpAdapter {
virtual const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() = 0;
virtual const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() = 0;
virtual const mindspore::HashMap<int, OutputDesc> &getOutputMap() = 0;
virtual const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() = 0;
virtual const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }

@@ -23,5 +23,12 @@ DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}};
ATTR_MAP(Case) = EMPTY_ATTR_MAP;
DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}};
DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}};
-REG_ADPT_DESC(Case, kNameCase, ADPT_DESC(Case))
+REG_ADPT_DESC(Case, kNameCase, ADPT_DESC(Case));

// While
DYN_INPUT_MAP(While) = {{1, DYN_INPUT_DESC(input)}};
ATTR_MAP(While) = {{"parallel_iterations", ATTR_DESC(parallel_iterations, AnyTraits<int32_t>())}};
DYN_OUTPUT_MAP(While) = {{0, DYN_OUTPUT_DESC(output)}};
SUBGRAPH_MAP(While) = {{0, SUBGRAPH_DESC(cond)}, {1, SUBGRAPH_DESC(body)}};
REG_ADPT_DESC(While, kNameWhile, ADPT_DESC(While));
} // namespace mindspore::transform

@@ -27,5 +27,11 @@ DECLARE_OP_ADAPTER(Case)
DECLARE_OP_USE_DYN_INPUT(Case)
DECLARE_OP_USE_DYN_SUBGRAPH(Case)
DECLARE_OP_USE_DYN_OUTPUT(Case)

DECLARE_OP_TYPE(While)
DECLARE_OP_ATTR(While)
DECLARE_OP_USE_DYN_INPUT(While)
DECLARE_OP_USE_SUBGRAPH(While)
DECLARE_OP_USE_DYN_OUTPUT(While)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_FUNCTIONAL_OPS_DECLARE_H_

@@ -33,10 +33,20 @@ namespace mindspore::transform {
template <> \
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;

#define DECLARE_OP_TYPE(T) using T = ge::op::T;

#define DECLARE_OP_ATTR(T) \
template <> \
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;

#define DECLARE_OP_USE_OUTPUT(T) \
template <> \
const mindspore::HashMap<int, OutputDesc> OpAdapter<T>::output_map_;

#define DECLARE_OP_USE_SUBGRAPH(T) \
template <> \
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_;

#define DECLARE_OP_USE_ENUM(T) \
template <> \
const mindspore::HashMap<std::string, int> OpAdapter<T>::enum_map_{};

@@ -57,6 +67,18 @@ namespace mindspore::transform {
template <> \
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;

#define SUBGRAPH_MAP(T) \
template <> \
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_
#define SUBGRAPH_DESC(name) \
{ \
#name, \
[](const OperatorPtr op, const DfGraphPtr graph) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->set_subgraph_builder_##name([graph](){return *graph;}); \
} \
}

#define INPUT_MAP(T) \
template <> \
const mindspore::HashMap<int, InputDesc> OpAdapter<T>::input_map_
@@ -147,13 +169,17 @@ namespace mindspore::transform {
template <> \
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_

-#define DYN_OUTPUT_DESC(name) \
-{ \
-#name, \
-[](const OperatorPtr op, unsigned int num) { \
-auto p = std::static_pointer_cast<OpType>(op); \
-(void)p->create_dynamic_output_##name(num); \
-} \
+#define DYN_OUTPUT_DESC(name) \
+{ \
+#name, \
+[](const OperatorPtr op, unsigned int num) { \
+auto p = std::static_pointer_cast<OpType>(op); \
+(void)p->create_dynamic_output_##name(num); \
+}, \
+[](const OperatorPtr op, uint32_t index, const GeTensorDesc tensor_desc) { \
+auto p = std::static_pointer_cast<OpType>(op); \
+(void)p->UpdateDynamicOutputDesc(#name, index, tensor_desc); \
+} \
}

#define ADPT_DESC_ONE(T) std::make_shared<OpAdapterDesc>(std::make_shared<OpAdapter<T>>())

@@ -452,5 +452,16 @@ std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
}
return ret;
}

std::string TransformUtil::NormOpName(const std::string &anf_name) {
std::string str = anf_name.substr(anf_name.rfind("/") + 1);
std::string ret;
for (const auto &c : str) {
if (std::isalnum(c) || c == '_' || c == '-') {
ret += c;
}
}
return ret;
}
} // namespace transform
} // namespace mindspore
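NormOpName, used in OpAdapter::generate for nodes flagged as subgraph nodes, keeps only the last '/'-separated segment of the ANF node name and drops any character that is not alphanumeric, '_' or '-'. A rough Python equivalent, offered as an illustrative sketch only (norm_op_name and the example name are not part of the change):

def norm_op_name(anf_name: str) -> str:
    # Keep the last path segment, then strip characters GE does not accept in subgraph node names.
    last = anf_name.rsplit("/", 1)[-1]
    return "".join(c for c in last if c.isalnum() or c in "_-")

# e.g. norm_op_name("Default/network/While_op.3") returns "While_op3"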
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test_ascend_control_sink """
import os
import pytest
import numpy as np
import mindspore.context as context
@@ -191,6 +192,45 @@ class NotOperation(nn.Cell):
        return not x_sum


class SimpleCell(nn.Cell):
    def __init__(self, i):
        super().__init__()
        self.i = i

    def construct(self, x):
        return self.i * x


class CellListInWhileByWhile(nn.Cell):
    def __init__(self):
        super().__init__()
        self.cell_list = nn.CellList()
        self.cell_list.append(SimpleCell(4))
        self.cell_list.append(SimpleCell(5))
        self.cell_list.append(SimpleCell(6))

    def construct(self, t, x):
        out = t
        while x < 3:
            out += 4
            x += 1
        x = 0
        while x < 3:
            add = self.cell_list[x](t)
            out = out + add
            x += 1
        return out


def cell_list_in_while_by_while():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = CellListInWhileByWhile()
    t = Tensor(10, mstype.int32)
    x = Tensor(0, mstype.int32)
    out = net(t, x)
    return out


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -347,3 +387,19 @@ def test_control_flow_ref():
    input_x = Tensor(6, ms.float32)
    out = net(input_x)
    assert out == 4


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cell_list_in_while_by_while_ge():
    """
    Feature: Control flow(while and case) implement in ge
    Description: run the whole graph sink in ascend in ge backend
    Expectation: success
    """
    os.environ['MS_ENABLE_GE'] = "1"
    out = cell_list_in_while_by_while()
    assert out == Tensor(172, mstype.int32)
    del os.environ['MS_ENABLE_GE']

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import numpy as np
import pytest

@@ -50,6 +50,33 @@ class CaseNet(nn.Cell):
        return x


class SimpleCell(nn.Cell):
    def __init__(self, i):
        super().__init__()
        self.i = i

    def construct(self, x):
        return self.i * x


class CellInList(nn.Cell):
    def __init__(self):
        super().__init__()
        self.cell_list = nn.CellList()
        self.cell_list.append(SimpleCell(4))
        self.cell_list.append(SimpleCell(5))
        self.cell_list.append(SimpleCell(6))

    def construct(self, t, x):
        out = t
        while x < 3:
            add = self.cell_list[x](t)
            out = out + add
            x += 1

        return out


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -79,31 +106,6 @@ def test_cell_in_list():
    Description: test recursive switch layer.
    Expectation: success if grad and output are correct.
    """

-    class TestCell(nn.Cell):
-        def __init__(self, i):
-            super().__init__()
-            self.i = i
-
-        def construct(self, x):
-            return self.i * x
-
-    class CellInList(nn.Cell):
-        def __init__(self):
-            super().__init__()
-            self.cell_list = nn.CellList()
-            self.cell_list.append(TestCell(4))
-            self.cell_list.append(TestCell(5))
-            self.cell_list.append(TestCell(6))
-
-        def construct(self, t, x):
-            out = t
-            while x < 3:
-                add = self.cell_list[x](t)
-                out = out + add
-                x += 1
-            return out
-
    net = CellInList()
    t = Tensor(10, mstype.int32)
    x = Tensor(0, mstype.int32)
@@ -113,3 +115,22 @@ def test_cell_in_list():

    assert out == Tensor(160, mstype.int32)
    assert grad_out == Tensor(16, mstype.int32)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cell_in_list_ge():
    """
    Feature: Switch layer in while in ge backend.
    Description: test recursive switch layer in ge backend.
    Expectation: success.
    """
    os.environ['MS_ENABLE_GE'] = "1"
    net = CellInList()
    t = Tensor(20, mstype.int32)
    x = Tensor(0, mstype.int32)
    out = net(t, x)
    del os.environ['MS_ENABLE_GE']
    assert out == Tensor(320, mstype.int32)