[ge_adapter]: imple control flow(while while by while and case in while)

This commit is contained in:
xiao_yao1994 2022-03-04 11:29:36 +08:00
parent ca68555cc0
commit 3c81a6e1eb
13 changed files with 1312 additions and 149 deletions

View File

@ -33,11 +33,13 @@
#include "include/common/utils/config_manager.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "ops/core_ops.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "include/transform/graph_ir/util.h"
#include "ir/tensor.h"
#include "include/transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/op_adapter.h"
#include "graph/operator_reg.h"
#include "external/ge/ge_api.h"
#include "graph/tensor.h"
@ -51,6 +53,8 @@ using TensorOrderMap = std::map<std::string, std::shared_ptr<tensor::Tensor>>;
using HcomBroadcast = ge::op::HcomBroadcast;
using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;
using ParamIndexMap = std::map<std::size_t, std::size_t>;
enum class GraphType { kNormal, kCond, kBody, kAfter, kBranch };
class COMMON_EXPORT DfGraphConvertor {
public:
explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
@ -101,6 +105,7 @@ class COMMON_EXPORT DfGraphConvertor {
fout.close();
#endif
}
void DrawSaveCheckpointGraph(const std::string &name) {
std::ofstream fout(name);
if (!fout.is_open()) {
@ -112,11 +117,13 @@ class COMMON_EXPORT DfGraphConvertor {
}
DfGraphConvertor &ConvertAllNode();
bool SetGraphInputs(const std::vector<Operator> &inputs);
DfGraphConvertor &BuildGraph();
DfGraphConvertor &InitParam(const TensorOrderMap &tensors);
DfGraphConvertor &GenerateCheckpointGraph();
DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors);
void InitParamWithData(const TensorOrderMap &tensors);
OutHandler GetNormalOpInput(const AnfNodePtr &pred);
void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node);
void SetupBroadcast(const std::shared_ptr<HcomBroadcast> &broadcast, const std::vector<GeTensorDesc> &broadcast_desc,
const DfGraphPtr &broadcast_graph, std::vector<ge::Operator> broadcast_input);
@ -150,6 +157,7 @@ class COMMON_EXPORT DfGraphConvertor {
AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
AnfNodePtr TraceDepend(const CNodePtr &node);
OutHandler TraceRealOp(AnfNodePtr node);
OutHandler GetHandler(const AnfNodePtr &node);
OutHandler GetHandler(const AnfNodePtr &node, const std::stack<uint64_t> &index_stack, AnfNode *const draw_index);
OperatorPtr Convert(AnfNodePtr node);
OperatorPtr ConvertCNode(CNodePtr node);
@ -194,8 +202,27 @@ class COMMON_EXPORT DfGraphConvertor {
void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src,
int index);
void UpdateTupleOutCache(void);
AnfNodePtr TransformConstOp(const CNodePtr &node, AnfNodePtr pred);
AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input);
void ConvertWhileNode(const CNodePtr &node);
void CacheWhileGraph(const CNodePtr &cnode);
void ConvertWhileBody(const AnfNodePtr &node);
std::shared_ptr<std::vector<Operator>> GetWhileSubGraphInput();
void BuildWhileSubGraph();
void ConvertWhileCond(const AnfNodePtr &node);
void ConvertWhileAfter(const AnfNodePtr &node);
void BuildWhileAfterSubGraph();
void GetCallNodeInputs(const CNodePtr &node);
std::vector<Operator> GetWhileBodyOutputs();
bool IsSubGraph() const { return graph_type_ == GraphType::kCond || graph_type_ == GraphType::kBody; }
bool IsAfterGraph() const { return graph_type_ == GraphType::kAfter; }
bool IsNormalGraph() const { return graph_type_ == GraphType::kNormal; }
bool IsBranchGraph() const { return graph_type_ == GraphType::kBranch; }
void SetParamIndexMap(const std::vector<AnfNodePtr> &graphs);
void SetWhileOutputHandle(const OperatorPtr &prev_while_op);
void GetWhileUsedInputIndex(const std::vector<AnfNodePtr> &graphs);
std::shared_ptr<AnfGraph> anf_graph_{nullptr};
std::shared_ptr<DfGraph> df_graph_{nullptr};
std::shared_ptr<DfGraph> init_graph_{nullptr};
@ -214,6 +241,7 @@ class COMMON_EXPORT DfGraphConvertor {
mindspore::HashMap<std::string, AnfNodePtr> params_;
mindspore::HashMap<std::string, OperatorPtr> vars_;
std::vector<std::pair<ge::Operator, std::string>> graph_outputs_;
std::vector<AnfNodePtr> graph_anf_outputs_;
std::vector<OperatorPtr> graph_const_inputs_;
std::vector<OperatorPtr> init_ops_;
std::vector<OperatorPtr> broadcast_ops_;
@ -223,6 +251,36 @@ class COMMON_EXPORT DfGraphConvertor {
bool training_ = false;
bool distribute_ = false;
bool use_inputs_ = false;
AnfNodePtr while_cond_node_ = nullptr;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_;
CNodePtr cur_while_node_ = nullptr;
size_t cur_while_node_out_size_ = 0;
mindspore::HashMap<size_t, OutHandler> while_const_input_index_;
mindspore::HashMap<size_t, OutHandler> prev_while_const_input_index_;
mindspore::HashMap<size_t, size_t> prev_cond_to_while_out_index_;
mindspore::HashMap<OperatorPtr, std::shared_ptr<tensor::Tensor>> const_op_to_value_;
AnfNodePtr prev_while_node_ = nullptr;
size_t prev_while_node_out_size_ = 0;
mindspore::HashMap<AnfNodePtr, std::vector<AnfNodePtr>> while_graph_cache_;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> call_input_handle_cache_;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> while_output_handle_cache_;
AnfNodePtr call_node_in_while_body_ = nullptr;
GraphType graph_type_ = GraphType::kNormal;
ParamIndexMap body_cond_map_;
ParamIndexMap after_cond_map_;
ParamIndexMap prev_after_cond_map_;
mindspore::HashMap<size_t, OperatorPtr> subgraph_input_cache_;
std::set<size_t> while_used_input_index_;
std::set<size_t> prev_while_used_input_index_;
mindspore::HashMap<size_t, OutHandler> bypass_node_prev_handle_cache_;
mindspore::HashMap<size_t, OutHandler> bypass_node_handle_cache_;
size_t case_call_input_size_;
};
} // namespace transform
} // namespace mindspore

View File

@ -349,6 +349,7 @@ constexpr const char kNameTensorArray[] = "TensorArray";
constexpr const char kNameTensorArrayWrite[] = "TensorArrayWrite";
constexpr const char kNameTensorArrayGather[] = "TensorArrayGather";
constexpr const char kNameTensorMove[] = "TensorMove";
constexpr const char kNameWhile[] = "While";
class OpAdapterDesc;

View File

@ -234,6 +234,14 @@ class COMMON_EXPORT TransformUtil {
}
return dest;
}
/*
* Parameters:
* anf_name: [string] the anf node name
* Return
* [string] operator name
* */
static std::string NormOpName(const std::string &anf_name);
};
} // namespace transform
} // namespace mindspore

File diff suppressed because it is too large Load Diff

View File

@ -120,6 +120,17 @@ Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, int index,
return NOT_FOUND;
}
Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs) {
MS_EXCEPTION_IF_NULL(op);
if (subgraph_map_.size() != subgraphs->size()) {
return INVALID_ARGUMENT;
}
for (size_t i = 0; i < subgraphs->size(); i++) {
subgraph_map_.at(i).set_subgraph(op, std::make_shared<DfGraph>((*subgraphs)[i]));
}
return SUCCESS;
}
Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
MS_EXCEPTION_IF_NULL(op);
MS_EXCEPTION_IF_NULL(input);
@ -349,7 +360,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
}
if (output_size != tuple_shp->shape().size()) {
MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
MS_LOG(ERROR) << "output_map is not equal to tuple_shape size";
return FAILED;
}

View File

@ -33,6 +33,7 @@ class OpAdapterImpl {
const mindspore::HashMap<int, DynInputDesc> &dyn_input_map,
const mindspore::HashMap<int, OutputDesc> &output_map,
const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map,
const mindspore::HashMap<int, SubGraphDesc> &subgraph_map,
const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map,
const mindspore::HashMap<std::string, AttrDesc> &attr_map,
const mindspore::HashMap<std::string, int> &enum_map,
@ -45,6 +46,7 @@ class OpAdapterImpl {
dyn_input_map_(dyn_input_map),
output_map_(output_map),
dyn_output_map_(dyn_output_map),
subgraph_map_(subgraph_map),
dyn_subgraph_map_(dyn_subgraph_map),
attr_map_(attr_map),
enum_map_(enum_map),
@ -65,6 +67,7 @@ class OpAdapterImpl {
Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
OperatorPtr GenerateCustomOp(const AnfNodePtr anf);
Status SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs);
Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches);
Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input);
Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input);
@ -100,6 +103,7 @@ class OpAdapterImpl {
const mindspore::HashMap<int, DynInputDesc> &dyn_input_map_;
const mindspore::HashMap<int, OutputDesc> &output_map_;
const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map_;
const mindspore::HashMap<int, SubGraphDesc> &subgraph_map_;
const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map_;
const mindspore::HashMap<std::string, AttrDesc> &attr_map_;
const mindspore::HashMap<std::string, int> &enum_map_;
@ -116,14 +120,14 @@ class OpAdapter : public BaseOpAdapter {
public:
using OpType = T;
OpAdapter()
: impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_,
: impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_,
dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_,
&cus_output_map_, &extra_attr_, &name_counts_, this)) {
MS_EXCEPTION_IF_NULL(impl_);
}
explicit OpAdapter(const ExtraAttr &extra_attr)
: extra_attr_(extra_attr),
impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_,
impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_,
dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_,
&cus_output_map_, &extra_attr_, &name_counts_, this)) {
MS_EXCEPTION_IF_NULL(impl_);
@ -148,8 +152,17 @@ class OpAdapter : public BaseOpAdapter {
// There are duplicate names in ANF graph, do not assign ANF node name to GE
// GE will generate unique name automatically
if (anf != nullptr && anf->fullname_with_scope() != "") {
MS_LOG(DEBUG) << anf->fullname_with_scope();
op = std::make_shared<OpType>(anf->fullname_with_scope());
auto name = anf->fullname_with_scope();
MS_LOG(DEBUG) << name;
string user_data_key = "subgraph_node";
if (anf->has_user_data(user_data_key) && *(anf->user_data<bool>(user_data_key))) {
auto norm_name = TransformUtil::NormOpName(name);
if (norm_name != name) {
MS_LOG(DEBUG) << "normalize anf name : " << name << " to operator name : " << norm_name;
}
name = norm_name;
}
op = std::make_shared<OpType>(name);
} else {
MS_LOG(DEBUG) << "no fullname_with_scope";
op = std::make_shared<OpType>();
@ -169,6 +182,28 @@ class OpAdapter : public BaseOpAdapter {
return op;
}
OperatorPtr GenerateDynamicOutputOp(const AnfNodePtr &anf) {
OperatorPtr op = nullptr;
// There are duplicate names in ANF graph, do not assign ANF node name to GE
// GE will generate unique name automatically
if (anf != nullptr && anf->fullname_with_scope() != "") {
MS_LOG(DEBUG) << anf->fullname_with_scope();
op = std::make_shared<OpType>(anf->fullname_with_scope());
} else {
MS_LOG(DEBUG) << "no fullname_with_scope";
op = std::make_shared<OpType>();
}
return op;
}
void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) override {
// set dynamic output num if op use DYNAMIC_OUTPUT
if ((op != nullptr) && (!dyn_output_map_.empty())) {
MS_LOG(DEBUG) << "create_dyn_output for node:" << op->GetName() << ", num:" << dyn_output_size;
dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(dyn_output_size));
}
}
OperatorPtr generate(const AnfNodePtr &anf) override {
OperatorPtr op = nullptr;
if (IsCustomCNode(anf)) {
@ -184,12 +219,30 @@ class OpAdapter : public BaseOpAdapter {
OperatorPtr generate(const std::string &op_name) override { return std::make_shared<OpType>(op_name); }
OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) override {
OperatorPtr op = nullptr;
op = GenerateDynamicOutputOp(anf);
if (op == nullptr) {
MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope();
}
return op;
}
const mindspore::HashMap<int, InputDesc> &getInputMap() override { return input_map_; }
const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() override { return subgraph_map_; }
const mindspore::HashMap<int, OutputDesc> &getOutputMap() override { return output_map_; }
const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }
Status SetOpSubgraphFunc(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) {
return impl_->SetOpSubgraphFunc(op, subgraphs);
}
int setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) override {
return static_cast<int>(SetOpSubgraphFunc(op, subgraphs));
}
Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) {
return impl_->SetOpSubgraphFunc(op, index, branches);
}
@ -428,6 +481,7 @@ class OpAdapter : public BaseOpAdapter {
static const mindspore::HashMap<int, DynInputDesc> dyn_input_map_;
static const mindspore::HashMap<int, OutputDesc> output_map_;
static const mindspore::HashMap<int, DynOutputDesc> dyn_output_map_;
static const mindspore::HashMap<int, SubGraphDesc> subgraph_map_;
static const mindspore::HashMap<int, DynSubGraphDesc> dyn_subgraph_map_;
static const mindspore::HashMap<std::string, AttrDesc> attr_map_;
static const mindspore::HashMap<std::string, int> enum_map_;
@ -449,6 +503,8 @@ const mindspore::HashMap<int, OutputDesc> OpAdapter<T>::output_map_;
template <typename T>
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
template <typename T>
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_;
template <typename T>
const mindspore::HashMap<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
template <typename T>
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;
@ -464,5 +520,4 @@ mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> OpAdapter<
// specialization for method
} // namespace transform
} // namespace mindspore
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_

View File

@ -63,7 +63,10 @@ using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr
using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
using UpdateDynOutputDescFunc = std::function<void(OperatorPtr, unsigned int, GeTensorDesc)>;
using SubGraphFunc = std::function<void(OperatorPtr, DfGraphPtr)>;
using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>;
using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>;
struct AttrDesc {
@ -85,6 +88,11 @@ struct DynInputDesc {
DynInputHandleFunc set_handle;
};
struct SubGraphDesc {
std::string name;
SubGraphFunc set_subgraph;
};
struct DynSubGraphDesc {
std::string name;
CreateDynSubGraphFunc create_dyn_subgraph;
@ -99,6 +107,7 @@ struct OutputDesc {
struct DynOutputDesc {
std::string name;
CreateDynOutputOpFunc create_dyn_output;
UpdateDynOutputDescFunc update_dyn_output_desc;
};
class BaseOpAdapter {
@ -106,6 +115,9 @@ class BaseOpAdapter {
virtual ~BaseOpAdapter() {}
virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
virtual OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) { return nullptr; }
virtual void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) { return; }
virtual int setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) = 0;
virtual int setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
@ -130,6 +142,7 @@ class BaseOpAdapter {
virtual const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() = 0;
virtual const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() = 0;
virtual const mindspore::HashMap<int, OutputDesc> &getOutputMap() = 0;
virtual const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() = 0;
virtual const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }

View File

@ -23,5 +23,12 @@ DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}};
ATTR_MAP(Case) = EMPTY_ATTR_MAP;
DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}};
DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}};
REG_ADPT_DESC(Case, kNameCase, ADPT_DESC(Case))
REG_ADPT_DESC(Case, kNameCase, ADPT_DESC(Case));
// While
DYN_INPUT_MAP(While) = {{1, DYN_INPUT_DESC(input)}};
ATTR_MAP(While) = {{"parallel_iterations", ATTR_DESC(parallel_iterations, AnyTraits<int32_t>())}};
DYN_OUTPUT_MAP(While) = {{0, DYN_OUTPUT_DESC(output)}};
SUBGRAPH_MAP(While) = {{0, SUBGRAPH_DESC(cond)}, {1, SUBGRAPH_DESC(body)}};
REG_ADPT_DESC(While, kNameWhile, ADPT_DESC(While));
} // namespace mindspore::transform

View File

@ -27,5 +27,11 @@ DECLARE_OP_ADAPTER(Case)
DECLARE_OP_USE_DYN_INPUT(Case)
DECLARE_OP_USE_DYN_SUBGRAPH(Case)
DECLARE_OP_USE_DYN_OUTPUT(Case)
DECLARE_OP_TYPE(While)
DECLARE_OP_ATTR(While)
DECLARE_OP_USE_DYN_INPUT(While)
DECLARE_OP_USE_SUBGRAPH(While)
DECLARE_OP_USE_DYN_OUTPUT(While)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_FUNCTIONAL_OPS_DECLARE_H_

View File

@ -33,10 +33,20 @@ namespace mindspore::transform {
template <> \
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;
#define DECLARE_OP_TYPE(T) using T = ge::op::T;
#define DECLARE_OP_ATTR(T) \
template <> \
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;
#define DECLARE_OP_USE_OUTPUT(T) \
template <> \
const mindspore::HashMap<int, OutputDesc> OpAdapter<T>::output_map_;
#define DECLARE_OP_USE_SUBGRAPH(T) \
template <> \
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_;
#define DECLARE_OP_USE_ENUM(T) \
template <> \
const mindspore::HashMap<std::string, int> OpAdapter<T>::enum_map_{};
@ -57,6 +67,18 @@ namespace mindspore::transform {
template <> \
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
#define SUBGRAPH_MAP(T) \
template <> \
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_
#define SUBGRAPH_DESC(name) \
{ \
#name, \
[](const OperatorPtr op, const DfGraphPtr graph) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->set_subgraph_builder_##name([graph](){return *graph;}); \
} \
}
#define INPUT_MAP(T) \
template <> \
const mindspore::HashMap<int, InputDesc> OpAdapter<T>::input_map_
@ -147,13 +169,17 @@ namespace mindspore::transform {
template <> \
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_
#define DYN_OUTPUT_DESC(name) \
{ \
#name, \
[](const OperatorPtr op, unsigned int num) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->create_dynamic_output_##name(num); \
} \
#define DYN_OUTPUT_DESC(name) \
{ \
#name, \
[](const OperatorPtr op, unsigned int num) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->create_dynamic_output_##name(num); \
}, \
[](const OperatorPtr op, uint32_t index, const GeTensorDesc tensor_desc) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->UpdateDynamicOutputDesc(#name, index, tensor_desc); \
} \
}
#define ADPT_DESC_ONE(T) std::make_shared<OpAdapterDesc>(std::make_shared<OpAdapter<T>>())

View File

@ -452,5 +452,16 @@ std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
}
return ret;
}
std::string TransformUtil::NormOpName(const std::string &anf_name) {
std::string str = anf_name.substr(anf_name.rfind("/") + 1);
std::string ret;
for (const auto &c : str) {
if (std::isalnum(c) || c == '_' || c == '-') {
ret += c;
}
}
return ret;
}
} // namespace transform
} // namespace mindspore

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test_ascend_control_sink """
import os
import pytest
import numpy as np
import mindspore.context as context
@ -191,6 +192,45 @@ class NotOperation(nn.Cell):
return not x_sum
class SimpleCell(nn.Cell):
def __init__(self, i):
super().__init__()
self.i = i
def construct(self, x):
return self.i * x
class CellListInWhileByWhile(nn.Cell):
def __init__(self):
super().__init__()
self.cell_list = nn.CellList()
self.cell_list.append(SimpleCell(4))
self.cell_list.append(SimpleCell(5))
self.cell_list.append(SimpleCell(6))
def construct(self, t, x):
out = t
while x < 3:
out += 4
x += 1
x = 0
while x < 3:
add = self.cell_list[x](t)
out = out + add
x += 1
return out
def cell_list_in_while_by_while():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = CellListInWhileByWhile()
t = Tensor(10, mstype.int32)
x = Tensor(0, mstype.int32)
out = net(t, x)
return out
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -347,3 +387,19 @@ def test_control_flow_ref():
input_x = Tensor(6, ms.float32)
out = net(input_x)
assert out == 4
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cell_list_in_while_by_while_ge():
"""
Feature: Control flow(while and case) implement in ge
Description: run the whole graph sink in ascend in ge backend
Expectation: success
"""
os.environ['MS_ENABLE_GE'] = "1"
out = cell_list_in_while_by_while()
assert out == Tensor(172, mstype.int32)
del os.environ['MS_ENABLE_GE']

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
@ -50,6 +50,33 @@ class CaseNet(nn.Cell):
return x
class SimpleCell(nn.Cell):
def __init__(self, i):
super().__init__()
self.i = i
def construct(self, x):
return self.i * x
class CellInList(nn.Cell):
def __init__(self):
super().__init__()
self.cell_list = nn.CellList()
self.cell_list.append(SimpleCell(4))
self.cell_list.append(SimpleCell(5))
self.cell_list.append(SimpleCell(6))
def construct(self, t, x):
out = t
while x < 3:
add = self.cell_list[x](t)
out = out + add
x += 1
return out
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -79,31 +106,6 @@ def test_cell_in_list():
Description: test recursive switch layer.
Expectation: success if grad and output are correct.
"""
class TestCell(nn.Cell):
def __init__(self, i):
super().__init__()
self.i = i
def construct(self, x):
return self.i * x
class CellInList(nn.Cell):
def __init__(self):
super().__init__()
self.cell_list = nn.CellList()
self.cell_list.append(TestCell(4))
self.cell_list.append(TestCell(5))
self.cell_list.append(TestCell(6))
def construct(self, t, x):
out = t
while x < 3:
add = self.cell_list[x](t)
out = out + add
x += 1
return out
net = CellInList()
t = Tensor(10, mstype.int32)
x = Tensor(0, mstype.int32)
@ -113,3 +115,22 @@ def test_cell_in_list():
assert out == Tensor(160, mstype.int32)
assert grad_out == Tensor(16, mstype.int32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cell_in_list_ge():
"""
Feature: Switch layer in while in ge backend.
Description: test recursive switch layer in ge backend.
Expectation: success.
"""
os.environ['MS_ENABLE_GE'] = "1"
net = CellInList()
t = Tensor(20, mstype.int32)
x = Tensor(0, mstype.int32)
out = net(t, x)
del os.environ['MS_ENABLE_GE']
assert out == Tensor(320, mstype.int32)