!44526 support dynamic input for GE

Merge pull request !44526 from xulei/ge_dynamic_shape
This commit is contained in:
i-robot 2022-11-16 02:06:52 +00:00 committed by Gitee
commit d147d1f202
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
28 changed files with 707 additions and 78 deletions

View File

@ -25,6 +25,20 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
template <typename Map, typename K = typename Map::key_type, typename V = typename Map::mapped_type>
std::string MapToString(const Map &value) {
std::stringstream buffer;
buffer << "{";
for (auto it = value.begin(); it != value.end(); it++) {
if (it != value.begin()) {
buffer << ", ";
}
buffer << it->first << ": " << it->second;
}
buffer << "}";
return buffer.str();
}
std::string GetErrorMessage(bool add_title = false); std::string GetErrorMessage(bool add_title = false);
std::string GetWarningMessage(); std::string GetWarningMessage();
void SetErrorManagerContext(); void SetErrorManagerContext();

View File

@ -35,6 +35,7 @@
#include "runtime/hardware/device_context_manager.h" #include "runtime/hardware/device_context_manager.h"
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h" #include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
#include "plugin/device/ascend/optimizer/ge_optimization.h" #include "plugin/device/ascend/optimizer/ge_optimization.h"
#include "plugin/device/ascend/hal/common/ascend_utils.h"
#include "runtime/config.h" #include "runtime/config.h"
#include "runtime/dev.h" #include "runtime/dev.h"
#include "distributed/init.h" #include "distributed/init.h"
@ -43,6 +44,8 @@ namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
namespace { namespace {
using mindspore::transform::OptionMap;
constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG"; constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG";
constexpr char kGeDumpMode[3][7] = {"all", "input", "output"}; constexpr char kGeDumpMode[3][7] = {"all", "input", "output"};
@ -101,6 +104,38 @@ transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
return res; return res;
} }
std::string ShapesToString(const ShapeArray &shapes) {
std::stringstream buffer;
for (size_t i = 0; i < shapes.size(); ++i) {
if (i != 0) {
buffer << ",";
}
buffer << "[";
const auto &shape = shapes[i];
for (size_t j = 0; j < shape.size(); ++j) {
if (j != 0) {
buffer << ",";
}
buffer << shape[j];
}
buffer << "]";
}
return buffer.str();
}
OptionMap GetComputeGraphOptions(const ShapeArray &input_shapes, bool is_dynamic_shape) {
OptionMap options{};
if (common::GetEnv("GE_TRAIN") == "1") {
(void)options.emplace("ge.exec.variable_acc", "1");
}
if (!is_dynamic_shape) {
return options;
}
(void)options.emplace("ge.exec.dynamicGraphExecuteMode", "dynamic_execute");
(void)options.emplace("ge.exec.dataInputsShapeRange", ShapesToString(input_shapes));
return options;
}
bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) { bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
MS_EXCEPTION_IF_NULL(anf_graph); MS_EXCEPTION_IF_NULL(anf_graph);
auto converter = transform::NewConverter(anf_graph); auto converter = transform::NewConverter(anf_graph);
@ -121,11 +156,9 @@ bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &
std::string graph_name = anf_graph->ToString(); std::string graph_name = anf_graph->ToString();
std::string init_graph = "init_subgraph." + graph_name; std::string init_graph = "init_subgraph." + graph_name;
std::string checkpoint_name = "save." + GetGraphName(anf_graph); std::string checkpoint_name = "save." + GetGraphName(anf_graph);
if (common::GetEnv("GE_TRAIN") == "1") { const auto options = GetComputeGraphOptions(converter->input_shapes(), converter->dynamic_shape_inputs());
(void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), {{"ge.exec.variable_acc", "1"}}); MS_LOG(INFO) << "Set options of compute graph: " << graph_name << " to " << MapToString(options);
} else { (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), options);
(void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter));
}
(void)transform::AddGraph(init_graph, transform::GetInitGraph(converter)); (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
(void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter)); (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));
@ -488,7 +521,7 @@ void GeDeviceContext::InitGe(const std::shared_ptr<MsContext> &inst_context) {
{ {
// Release GIL before calling into (potentially long-running) C++ code // Release GIL before calling into (potentially long-running) C++ code
mindspore::ScopedLongRunning long_running; mindspore::ScopedLongRunning long_running;
if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) { if (::ge::GEInitialize(ge_options) != ::ge::GRAPH_SUCCESS) {
MS_LOG(EXCEPTION) << "Initialize GE failed!"; MS_LOG(EXCEPTION) << "Initialize GE failed!";
} }
} }
@ -670,7 +703,7 @@ bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context)
std::string exName(abi::__cxa_current_exception_type()->name()); std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName; MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
} }
if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { if (::ge::GEFinalize() != ::ge::GRAPH_SUCCESS) {
MS_LOG(WARNING) << "Finalize GE failed!"; MS_LOG(WARNING) << "Finalize GE failed!";
} }
inst_context->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false); inst_context->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);

View File

@ -20,6 +20,7 @@
#include <algorithm> #include <algorithm>
#include <queue> #include <queue>
#include <stack> #include <stack>
#include <unordered_set>
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "mindspore/core/ops/core_ops.h" #include "mindspore/core/ops/core_ops.h"
#include "utils/anf_utils.h" #include "utils/anf_utils.h"
@ -392,14 +393,16 @@ DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
return *this; return *this;
} }
// Processing input with MakeDatasetHandler // Processing input with MakeDatasetHandler and check whether input is dynamic
for (auto &it : anf_graph_->inputs()) { for (auto &it : anf_graph_->inputs()) {
auto op_itor = op_cache_.find(it.get()); // converted node auto op_itor = op_cache_.find(it.get()); // converted node
if (it->isa<Parameter>() && op_itor != op_cache_.end()) { if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
string name = std::static_pointer_cast<Parameter>(it)->name(); const auto &param = std::static_pointer_cast<Parameter>(it);
string name = param->name();
auto tensor_itor = tensors.find(name); // in init value map auto tensor_itor = tensors.find(name); // in init value map
if (tensor_itor == tensors.end()) { if (tensor_itor == tensors.end()) {
DfGraphConvertor::MakeDatasetHandler(name, input_idx, it); MakeDatasetHandler(name, input_idx, it);
AddGraphDynamicInput(param);
input_idx++; input_idx++;
} }
} }
@ -1941,6 +1944,19 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
DfGraphConvertor::SetOpInput(adpt, cnode); DfGraphConvertor::SetOpInput(adpt, cnode);
} }
void DfGraphConvertor::AddGraphDynamicInput(const ParameterPtr &param) {
MS_EXCEPTION_IF_NULL(param);
const auto &base_shape = param->Shape();
MS_EXCEPTION_IF_NULL(base_shape);
const auto &shape = base_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
const auto &sv = shape->shape();
if (std::any_of(sv.cbegin(), sv.cend(), [](const auto e) { return e == -1; })) {
dynamic_shape_inputs_ = true;
}
(void)input_shapes_.emplace_back(sv);
}
std::string DfGraphConvertor::GetGNodeName(const ::ge::GNode &node) const { std::string DfGraphConvertor::GetGNodeName(const ::ge::GNode &node) const {
::ge::AscendString name; ::ge::AscendString name;
auto ret = node.GetName(name); auto ret = node.GetName(name);
@ -2224,7 +2240,7 @@ OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) {
return nullptr; return nullptr;
} }
void DfGraphConvertor::ConvertTopK(const CNodePtr node) { void DfGraphConvertor::ConvertTopK(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32."; MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
auto value_ptr = node->input(2)->cast<ValueNodePtr>(); auto value_ptr = node->input(2)->cast<ValueNodePtr>();
@ -2372,7 +2388,7 @@ std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) const {
return cur_value; return cur_value;
} }
void DfGraphConvertor::ConvertReshape(const CNodePtr node) { void DfGraphConvertor::ConvertReshape(const CNodePtr &node) {
MS_LOG(INFO) << "Convert the second input of reshape to op attr."; MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
const auto kInputNum = 3; const auto kInputNum = 3;
if (node->size() < kInputNum) { if (node->size() < kInputNum) {
@ -2386,10 +2402,7 @@ void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
auto op = adpt->generate(node); auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(op);
// get shape form attr // get shape form attr
auto value_node = node->input(0)->cast<ValueNodePtr>(); auto primitive = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(value_node->value());
auto primitive = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto value = primitive->GetAttr("shape"); auto value = primitive->GetAttr("shape");
std::vector<int64_t> list; std::vector<int64_t> list;
@ -2399,7 +2412,34 @@ void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
op_cache_[node.get()] = op; op_cache_[node.get()] = op;
} }
void DfGraphConvertor::ConvertAllReduce(const CNodePtr node) { void DfGraphConvertor::ConvertDynamicStitch(const CNodePtr &node) {
MS_LOG(INFO) << "Convert and set 'N' attr of DynamicStitch.";
OpAdapterPtr adpt = FindAdapter(node, training_);
if (adpt == nullptr) {
return;
}
auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op);
int64_t input_length = 0;
auto indices = node->input(1);
MS_EXCEPTION_IF_NULL(indices);
if (indices->isa<CNode>()) {
input_length = SizeToLong(indices->cast<CNodePtr>()->size()) - 1;
} else if (IsValueNode<ValueSequence>(indices)) {
const auto tuple = GetValueNode<ValueSequencePtr>(indices);
MS_EXCEPTION_IF_NULL(tuple);
input_length = SizeToLong(tuple->size());
} else {
MS_LOG(EXCEPTION) << "Input 1 of DynamicStitch is neither CNode nor ValueNode contains ValueSequence, but "
<< indices->ToString() << ", can not set 'N' attr.";
}
(void)op->SetAttr("N", input_length);
MS_LOG(INFO) << "Set 'N' attr of DynamicStitch to " << input_length;
op_cache_[node.get()] = op;
}
void DfGraphConvertor::ConvertAllReduce(const CNodePtr &node) {
MS_LOG(INFO) << "Add AllReduce fusion_id"; MS_LOG(INFO) << "Add AllReduce fusion_id";
OpAdapterPtr adpt = FindAdapter(node, training_); OpAdapterPtr adpt = FindAdapter(node, training_);
if (adpt == nullptr) { if (adpt == nullptr) {
@ -2408,10 +2448,7 @@ void DfGraphConvertor::ConvertAllReduce(const CNodePtr node) {
auto op = adpt->generate(node); auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(op);
// get shape form attr // get shape form attr
auto value_node = node->input(0)->cast<ValueNodePtr>(); auto primitive = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(value_node->value());
auto primitive = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto fusion_value = primitive->GetAttr("fusion"); auto fusion_value = primitive->GetAttr("fusion");
auto fusion = GetValue<int64_t>(fusion_value); auto fusion = GetValue<int64_t>(fusion_value);
@ -2425,7 +2462,8 @@ void DfGraphConvertor::ConvertAllReduce(const CNodePtr node) {
op_cache_[node.get()] = op; op_cache_[node.get()] = op;
} }
void DfGraphConvertor::ConvertConv2D(const CNodePtr node) { void DfGraphConvertor::ConvertConv2D(const CNodePtr &node) {
MS_LOG(INFO) << "Convert and set 'padding' attr for Conv2D-like op.";
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
OpAdapterPtr adpt = FindAdapter(node, training_); OpAdapterPtr adpt = FindAdapter(node, training_);
if (adpt == nullptr) { if (adpt == nullptr) {
@ -2433,20 +2471,33 @@ void DfGraphConvertor::ConvertConv2D(const CNodePtr node) {
} }
auto op = adpt->generate(node); auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(op);
auto value_node = node->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(value_node->value());
auto primitive = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(primitive);
auto value = primitive->GetAttr("padding");
if (value != nullptr) {
std::string pad_mode = GetValue<std::string>(value);
(void)op->SetAttr("padding", pad_mode);
}
op_cache_[node.get()] = op; op_cache_[node.get()] = op;
auto primitive = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
std::string pad_mode;
if (auto value = primitive->GetAttr("padding"); value != nullptr) {
pad_mode = GetValue<std::string>(value);
} else if (auto value = primitive->GetAttr("pad_mode"); value != nullptr) {
// Get 'pad_mode' attr and set it to 'padding' attr for ge
const mindspore::HashMap<int64_t, std::string> pad_mode_map{{1, "SAME"}, {2, "VALID"}};
if (value->isa<StringImm>()) {
pad_mode = GetValue<std::string>(value);
(void)std::transform(pad_mode.cbegin(), pad_mode.cend(), pad_mode.begin(), toupper);
} else if (auto it = pad_mode_map.find(GetValue<int64_t>(value)); it != pad_mode_map.cend()) {
// 'pad_mode' attr could be an enumeration
pad_mode = it->second;
} else {
return;
}
} else {
MS_LOG(INFO) << "Node: " << node->fullname_with_scope() << " has no 'padding' or 'pad_mode' attr";
return;
}
MS_LOG(INFO) << "Set 'padding' attr of node: " << node->fullname_with_scope() << " to " << pad_mode;
(void)op->SetAttr("padding", pad_mode);
} }
void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr node) { void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr &node) {
MS_LOG(INFO) << "Add OCRRecognitionPreHandle _op_max_shape attr"; MS_LOG(INFO) << "Add OCRRecognitionPreHandle _op_max_shape attr";
OpAdapterPtr adpt = FindAdapter(node, training_); OpAdapterPtr adpt = FindAdapter(node, training_);
if (adpt == nullptr) { if (adpt == nullptr) {
@ -2455,10 +2506,7 @@ void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr node) {
auto op = adpt->generate(node); auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(op);
// get shape form attr // get shape form attr
auto value_node = node->input(0)->cast<ValueNodePtr>(); auto primitive = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(value_node->value());
auto primitive = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto value = primitive->GetAttr("_op_max_shape"); auto value = primitive->GetAttr("_op_max_shape");
if (value == nullptr) { if (value == nullptr) {
@ -2517,33 +2565,26 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return false; return false;
} }
// Convert TopK second input from int64 to int32. const mindspore::HashMap<std::string, std::function<void(decltype(this), const CNodePtr &)>>
if (name == prim::kPrimTopK->name()) { auxiliary_node_converters{
ConvertTopK(node); // Convert TopK second input from int64 to int32.
return true; {prim::kPrimTopK->name(), &DfGraphConvertor::ConvertTopK},
} // Convert Reshape add const input to attr(shape)
{prim::kPrimReshape->name(), &DfGraphConvertor::ConvertReshape},
{prim::kPrimOCRRecognitionPreHandle->name(), &DfGraphConvertor::ConvertOCRRecPreHandle},
// Add attr 'pad_mode' to Conv2D-like op
{prim::kPrimConv2D->name(), &DfGraphConvertor::ConvertConv2D},
{prim::kPrimDepthwiseConv2dNative->name(), &DfGraphConvertor::ConvertConv2D},
{kNameConv2DBackpropInputV2, &DfGraphConvertor::ConvertConv2D},
{prim::kPrimConv2DBackpropInput->name(), &DfGraphConvertor::ConvertConv2D},
{prim::kPrimConv2DBackpropFilter->name(), &DfGraphConvertor::ConvertConv2D},
// Add attr 'N' to DynamicStitch
{prim::kPrimDynamicStitch->name(), &DfGraphConvertor::ConvertDynamicStitch},
{prim::kPrimAllReduce->name(), &DfGraphConvertor::ConvertAllReduce},
};
// Convert Reshape add const input to attr(shape) if (const auto it = auxiliary_node_converters.find(name); it != auxiliary_node_converters.cend()) {
if (name == prim::kPrimReshape->name()) { it->second(this, node);
ConvertReshape(node);
return true;
}
if (name == prim::kPrimOCRRecognitionPreHandle->name()) {
ConvertOCRRecPreHandle(node);
return true;
}
// Add attr pad mode to Conv2D
if (name == prim::kPrimConv2D->name() || name == prim::kPrimDepthwiseConv2dNative->name() ||
name == kNameConv2DBackpropInputV2) {
ConvertConv2D(node);
return true;
}
if (name == prim::kPrimAllReduce->name()) {
ConvertAllReduce(node);
return true;
} }
return true; return true;

View File

@ -156,6 +156,8 @@ class DfGraphConvertor {
bool is_training() const { return training_; } bool is_training() const { return training_; }
void set_training(bool is_training) { training_ = is_training; } void set_training(bool is_training) { training_ = is_training; }
bool dynamic_shape_inputs() const { return dynamic_shape_inputs_; }
std::vector<ShapeVector> input_shapes() { return input_shapes_; }
protected: protected:
void InitLoopVar(std::vector<::ge::Operator> *init_input); void InitLoopVar(std::vector<::ge::Operator> *init_input);
@ -178,15 +180,16 @@ class DfGraphConvertor {
OperatorPtr ConvertValueNode(ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node);
void SaveParamFormat(CNodePtr node); void SaveParamFormat(CNodePtr node);
void GetBranchNodeInput(const CNodePtr node); void GetBranchNodeInput(const CNodePtr node);
void ConvertTopK(const CNodePtr node); void ConvertTopK(const CNodePtr &node);
void ConvertResizeBilinear(const FuncGraphPtr anf_graph) const; void ConvertResizeBilinear(const FuncGraphPtr anf_graph) const;
void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) const; void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) const;
void ConvertTile(const FuncGraphPtr anf_graph) const; void ConvertTile(const FuncGraphPtr anf_graph) const;
AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const; AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const;
void ConvertReshape(const CNodePtr node); void ConvertReshape(const CNodePtr &node);
void ConvertAllReduce(const CNodePtr node); void ConvertAllReduce(const CNodePtr &node);
void ConvertOCRRecPreHandle(const CNodePtr node); void ConvertOCRRecPreHandle(const CNodePtr &node);
void ConvertConv2D(const CNodePtr node); void ConvertConv2D(const CNodePtr &node);
void ConvertDynamicStitch(const CNodePtr &node);
std::vector<int64_t> CastToInt(const ValuePtr &value) const; std::vector<int64_t> CastToInt(const ValuePtr &value) const;
bool CheckCNode(const std::string &name, const CNodePtr node); bool CheckCNode(const std::string &name, const CNodePtr node);
void SetNodeInput(AnfNodePtr node); void SetNodeInput(AnfNodePtr node);
@ -198,6 +201,7 @@ class DfGraphConvertor {
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
void AddGraphConstInput(const OperatorPtr &op); void AddGraphConstInput(const OperatorPtr &op);
void AddGraphDynamicInput(const ParameterPtr &param);
AnfNodePtr ParseLoadInput(const CNodePtr &cnode) const; AnfNodePtr ParseLoadInput(const CNodePtr &cnode) const;
void SetGraphInputs(std::vector<Operator> *inputs); void SetGraphInputs(std::vector<Operator> *inputs);
void TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred); void TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred);
@ -264,6 +268,7 @@ class DfGraphConvertor {
std::vector<OperatorPtr> init_ops_; std::vector<OperatorPtr> init_ops_;
std::vector<OperatorPtr> broadcast_ops_; std::vector<OperatorPtr> broadcast_ops_;
std::vector<AnfNodePtr> inputs_; std::vector<AnfNodePtr> inputs_;
ShapeArray input_shapes_;
OperatorPtr dataset_iter_getnext_; OperatorPtr dataset_iter_getnext_;
OperatorPtr queue_data_; OperatorPtr queue_data_;
OperatorPtr get_next_from_queue_; OperatorPtr get_next_from_queue_;
@ -271,6 +276,7 @@ class DfGraphConvertor {
bool training_ = false; bool training_ = false;
bool distribute_ = false; bool distribute_ = false;
bool use_inputs_ = false; bool use_inputs_ = false;
bool dynamic_shape_inputs_ = false;
AnfNodePtr while_cond_node_ = nullptr; AnfNodePtr while_cond_node_ = nullptr;
mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_; mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_;

View File

@ -396,6 +396,9 @@ constexpr const char kNameTransData[] = "TransData";
constexpr const char kNameSend[] = "Send"; constexpr const char kNameSend[] = "Send";
constexpr const char kNameReceive[] = "Receive"; constexpr const char kNameReceive[] = "Receive";
constexpr const char kNameIndexAdd[] = "IndexAdd"; constexpr const char kNameIndexAdd[] = "IndexAdd";
constexpr const char kNameUnique[] = "Unique";
constexpr const char kNameDynamicBroadcastGradientArgs[] = "DynamicBroadcastGradientArgs";
constexpr const char kNameDynamicStitch[] = "DynamicStitch";
class OpAdapterDesc; class OpAdapterDesc;

View File

@ -160,5 +160,15 @@ REG_ADPT_DESC(IdentityNMakeTuple, kNameMakeTuple, ADPT_DESC(IdentityN))
REG_ADPT_DESC(IdentityNDepend, kNameDepend, ADPT_DESC(IdentityN)) REG_ADPT_DESC(IdentityNDepend, kNameDepend, ADPT_DESC(IdentityN))
REG_ADPT_DESC(IdentityNReturn, kNameReturn, ADPT_DESC(IdentityN)) REG_ADPT_DESC(IdentityNReturn, kNameReturn, ADPT_DESC(IdentityN))
// IdentityN // Unique
INPUT_MAP(Unique) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unique) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Unique) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(idx)}};
REG_ADPT_DESC(Unique, kNameUnique, ADPT_DESC(Unique))
// BroadcastGradientArgs
INPUT_MAP(BroadcastGradientArgs) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BroadcastGradientArgs) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BroadcastGradientArgs) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}};
REG_ADPT_DESC(BroadcastGradientArgs, kNameDynamicBroadcastGradientArgs, ADPT_DESC(BroadcastGradientArgs))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -84,5 +84,11 @@ DECLARE_OP_USE_OUTPUT(Identity)
DECLARE_OP_ADAPTER(IdentityN) DECLARE_OP_ADAPTER(IdentityN)
DECLARE_OP_USE_DYN_OUTPUT(IdentityN) DECLARE_OP_USE_DYN_OUTPUT(IdentityN)
DECLARE_OP_ADAPTER(Unique)
DECLARE_OP_USE_OUTPUT(Unique)
DECLARE_OP_ADAPTER(BroadcastGradientArgs)
DECLARE_OP_USE_OUTPUT(BroadcastGradientArgs)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_

View File

@ -40,4 +40,11 @@ ATTR_MAP(TensorArrayGather) = {{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
{"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())}}; {"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(TensorArrayGather) = {{0, OUTPUT_DESC(value)}}; OUTPUT_MAP(TensorArrayGather) = {{0, OUTPUT_DESC(value)}};
REG_ADPT_DESC(TensorArrayGather, kNameTensorArrayGather, ADPT_DESC(TensorArrayGather)) REG_ADPT_DESC(TensorArrayGather, kNameTensorArrayGather, ADPT_DESC(TensorArrayGather))
// DynamicStitch
INPUT_MAP(DynamicStitch) = EMPTY_INPUT_MAP;
DYN_INPUT_MAP(DynamicStitch) = {{1, DYN_INPUT_DESC(indices)}, {2, DYN_INPUT_DESC(x)}};
ATTR_MAP(DynamicStitch) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DynamicStitch) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(DynamicStitch, kNameDynamicStitch, ADPT_DESC(DynamicStitch))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -29,5 +29,8 @@ DECLARE_OP_USE_OUTPUT(TensorArrayWrite)
DECLARE_OP_ADAPTER(TensorArrayGather) DECLARE_OP_ADAPTER(TensorArrayGather)
DECLARE_OP_USE_OUTPUT(TensorArrayGather) DECLARE_OP_USE_OUTPUT(TensorArrayGather)
DECLARE_OP_ADAPTER(DynamicStitch)
DECLARE_OP_USE_OUTPUT(DynamicStitch)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_

View File

@ -14,5 +14,6 @@
# ============================================================================ # ============================================================================
import os import os
os.environ['MS_ENABLE_GE'] = '1' os.environ['MS_ENABLE_GE'] = '1'
os.environ['MS_GE_TRAIN'] = '0' os.environ['MS_GE_TRAIN'] = '0'

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np import numpy as np
import ge_train_env # pylint: disable=unused-import from tests.st.ge import ge_train_env # pylint: disable=unused-import
import mindspore as ms import mindspore as ms
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn

View File

@ -15,7 +15,7 @@
""" test ge frontend pass `AvgPoolGradForGE` """ """ test ge frontend pass `AvgPoolGradForGE` """
import numpy as np import numpy as np
import ge_infer_env # pylint: disable=unused-import from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import operations as op from mindspore.ops import operations as op

View File

@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
import numpy as np import numpy as np
import ge_infer_env # pylint: disable=unused-import from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import mindspore import mindspore
from mindspore import context, nn, Tensor, ops from mindspore import context, nn, Tensor, ops

View File

@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
import numpy as np import numpy as np
import ge_infer_env # pylint: disable=unused-import from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import mindspore import mindspore
from mindspore import context, nn, Tensor, ops from mindspore import context, nn, Tensor, ops

View File

@ -15,7 +15,7 @@
""" test ge frontend pass `DropoutForGE` `DropoutGradForGE` """ """ test ge frontend pass `DropoutForGE` `DropoutGradForGE` """
import numpy as np import numpy as np
import ge_infer_env # pylint: disable=unused-import from tests.st.ge import ge_infer_env # pylint: disable=unused-import
from mindspore import ops, nn, context, Tensor from mindspore import ops, nn, context, Tensor
from mindspore.ops.composite import GradOperation from mindspore.ops.composite import GradOperation

View File

@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
import numpy as np import numpy as np
import ge_infer_env # pylint: disable=unused-import from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import mindspore import mindspore
from mindspore import context, nn, Tensor, ops from mindspore import context, nn, Tensor, ops

View File

@ -15,7 +15,7 @@
""" test ge frontend pass and op `TensorArray`""" """ test ge frontend pass and op `TensorArray`"""
import numpy as np import numpy as np
import ge_infer_env # pylint: disable=unused-import from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import mindspore.context as context import mindspore.context as context
from mindspore import nn from mindspore import nn
from mindspore import Tensor from mindspore import Tensor

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import pytest import pytest
import ge_test_utils as utils import tests.st.ge.ge_test_utils as utils
@pytest.mark.level0 @pytest.mark.level0

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import pytest import pytest
import ge_test_utils as utils import tests.st.ge.ge_test_utils as utils
@pytest.mark.level1 @pytest.mark.level1

View File

@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import inspect
import os import os
import sys import sys
def run_testcase(file_name, case_name=""): def run_testcase(file_name, case_name=""):
caller_working_dir = os.path.dirname(inspect.stack()[1][1])
os.chdir(caller_working_dir)
log_file = file_name + "_" + case_name + '.log' log_file = file_name + "_" + case_name + '.log'
if case_name == "": if case_name == "":
ret = os.system(f'{sys.executable} {file_name}.py &> {log_file}') ret = os.system(f'{sys.executable} {file_name}.py &> {log_file}')

View File

@ -14,5 +14,6 @@
# ============================================================================ # ============================================================================
import os import os
os.environ['MS_ENABLE_GE'] = '1' os.environ['MS_ENABLE_GE'] = '1'
os.environ['MS_GE_TRAIN'] = '1' os.environ['MS_GE_TRAIN'] = '1'

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _inner_ops
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.args = _inner_ops.DynamicBroadcastGradientArgs()
def construct(self, s0, s1):
r = self.args(s0, s1)
return (r[0], r[1])
def test_broadcast_gradient_args():
"""
Feature: for DynamicBroadcastGradientArgs op
Description: inputs are two shapes
Expectation: the result is correct
"""
shape0 = (4, 2, 1)
shape1 = (2, 7)
net = Net()
r0, r1 = net(shape0, shape1)
r0_expected = [2]
r1_expected = [0]
assert np.array_equal(r0_expected, r0.asnumpy())
assert np.array_equal(r1_expected, r1.asnumpy())
if __name__ == "__main__":
test_broadcast_gradient_args()

View File

@ -0,0 +1,81 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
out_channel = 4
kernel_size = 1
self.conv_filter = G.Conv2DBackpropFilter(out_channel,
kernel_size,
pad_mode="valid",
pad=0,
mode=1,
stride=1,
dilation=1,
group=1)
self.w = Parameter(
initializer(Tensor(np.array(
[[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
name='w')
self.x = Parameter(initializer(Tensor(np.array([[[
[3, 0, 1, 2, 7, 4],
[1, 5, 8, 9, 3, 1],
[2, 7, 2, 5, 1, 3],
[0, 1, 3, 1, 7, 8],
[4, 2, 1, 6, 2, 8],
[2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
self.out = Parameter(initializer(Tensor(np.array([[[
[-5, -4, 0, 8],
[-10, -2, 2, 3],
[0, -2, -4, -7],
[-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
self.get_shape = P.Shape()
@ms_function
def construct(self):
return self.conv_filter(self.out, self.x, self.get_shape(self.w))
def test_conv2d_backprop_filter():
"""
Feature: for Conv2DBackpropFilter op
Description: inputs are integers
Expectation: the result is correct
"""
conv2d_filter = Net()
output = conv2d_filter()
expect = np.array([[[[-60, -142, -265],
[-104, -211, -322],
[-102, -144, -248]]]]).astype(np.float32)
assert (output.asnumpy() == expect).all()
if __name__ == "__main__":
test_conv2d_backprop_filter()

View File

@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
out_channel = 4
kernel_size = 1
self.conv_input = P.Conv2DBackpropInput(out_channel,
kernel_size,
pad_mode="valid",
pad=0,
mode=1,
stride=1,
dilation=1,
group=1)
self.w = Parameter(
initializer(Tensor(np.array(
[[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
name='w')
self.x = Parameter(initializer(Tensor(np.array([[[
[3, 0, 1, 2, 7, 4],
[1, 5, 8, 9, 3, 1],
[2, 7, 2, 5, 1, 3],
[0, 1, 3, 1, 7, 8],
[4, 2, 1, 6, 2, 8],
[2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
self.out = Parameter(initializer(Tensor(np.array([[[
[-5, -4, 0, 8],
[-10, -2, 2, 3],
[0, -2, -4, -7],
[-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
self.get_shape = P.Shape()
@ms_function
def construct(self):
return self.conv_input(self.out, self.w, self.get_shape(self.x))
def test_conv2d_backprop_input():
"""
Feature: for Conv2DBackpropInput op
Description: inputs are integers
Expectation: the result is correct
"""
conv2d_input = Net()
output = conv2d_input()
expect = np.array([[[[-5, -4, 5, 12, 0, -8],
[-15, -6, 17, 17, -2, -11],
[-15, -8, 13, 12, 2, -4],
[-13, -6, 8, -14, 5, 20],
[-3, -4, -4, -19, 7, 23],
[-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)
assert (output.asnumpy() == expect).all()
if __name__ == "__main__":
test_conv2d_backprop_input()

View File

@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
out_channel = 4
kernel_size = 1
self.conv_input = P.Conv2DTranspose(out_channel,
kernel_size,
pad_mode="valid",
pad=0,
mode=1,
stride=1,
dilation=1,
group=1)
self.w = Parameter(
initializer(Tensor(np.array(
[[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
name='w')
self.x = Parameter(initializer(Tensor(np.array([[[
[3, 0, 1, 2, 7, 4],
[1, 5, 8, 9, 3, 1],
[2, 7, 2, 5, 1, 3],
[0, 1, 3, 1, 7, 8],
[4, 2, 1, 6, 2, 8],
[2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
self.out = Parameter(initializer(Tensor(np.array([[[
[-5, -4, 0, 8],
[-10, -2, 2, 3],
[0, -2, -4, -7],
[-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
self.get_shape = P.Shape()
@ms_function
def construct(self):
return self.conv_input(self.out, self.w, self.get_shape(self.x))
def test_conv2d_transpose():
"""
Feature: for Conv2DTranspose op
Description: inputs are integers
Expectation: the result is correct
"""
conv2d_input = Net()
output = conv2d_input()
expect = np.array([[[[-5, -4, 5, 12, 0, -8],
[-15, -6, 17, 17, -2, -11],
[-15, -8, 13, 12, 2, -4],
[-13, -6, 8, -14, 5, 20],
[-3, -4, -4, -19, 7, 23],
[-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)
assert (output.asnumpy() == expect).all()
if __name__ == "__main__":
test_conv2d_transpose()

View File

@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import numpy as np
import mindspore.context as context
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.operations as ops
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.d_shape = ops.TensorShape()
self.d_broadcastto = inner.DynamicBroadcastTo()
def construct(self, data, shape):
shape = self.d_shape(shape)
return self.d_broadcastto(data, shape)
def test_dynamic_broadcast_to():
"""
Feature: for DynamicBroadcastTo op
Description: inputs are data and shape
Expectation: the result is correct
"""
data = Tensor(np.array([1, 2, 3]), mstype.float32)
shape = Tensor(np.zeros((2, 3)), mstype.int64)
expect_data = np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32)
net = Net()
output = net(data, shape)
assert np.array_equal(output.asnumpy(), expect_data)
if __name__ == "__main__":
test_dynamic_broadcast_to()

View File

@ -0,0 +1,94 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import tests.st.ge.ge_test_utils as utils
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast_gradient_args():
"""
Feature: for DynamicBroadcastGradientArgs op
Description: inputs are two shapes
Expectation: the result is correct
"""
utils.run_testcase('broadcast_gradient_args')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_conv2d_backprop_filter():
"""
Feature: for Conv2DBackpropFilter op
Description: inputs are integers
Expectation: the result is correct
"""
utils.run_testcase('conv2d_backprop_filter')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_conv2d_backprop_input():
"""
Feature: for Conv2DBackpropInput op
Description: inputs are integers
Expectation: the result is correct
"""
utils.run_testcase('conv2d_backprop_input')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_conv2d_transpose():
"""
Feature: for Conv2DTranspose op
Description: inputs are integers
Expectation: the result is correct
"""
utils.run_testcase('conv2d_transpose')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_broadcast_to():
"""
Feature: for DynamicBroadcastTo op
Description: inputs are data and shape
Expectation: the result is correct
"""
utils.run_testcase('dynamic_broadcast_to')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unique():
"""
Feature: for Unique op
Description: inputs are integers
Expectation: the result is correct
"""
utils.run_testcase('unique')

53
tests/st/ge/ops/unique.py Normal file
View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from tests.st.ge import ge_infer_env # pylint: disable=unused-import
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.unique = P.Unique()
def construct(self, x):
x = self.unique(x)
return (x[0], x[1])
def test_unique():
"""
Feature: for Unique op
Description: inputs are integers
Expectation: the result is correct
"""
x = Tensor(np.array([1, 1, 2, 3, 3, 3]), mstype.int32)
unique = Net()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([0, 0, 1, 2, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
if __name__ == "__main__":
test_unique()