forked from mindspore-Ecosystem/mindspore

commit d147d1f202
!44526 support dynamic input for GE
Merge pull request !44526 from xulei/ge_dynamic_shape
@@ -25,6 +25,20 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
+template <typename Map, typename K = typename Map::key_type, typename V = typename Map::mapped_type>
+std::string MapToString(const Map &value) {
+  std::stringstream buffer;
+  buffer << "{";
+  for (auto it = value.begin(); it != value.end(); it++) {
+    if (it != value.begin()) {
+      buffer << ", ";
+    }
+    buffer << it->first << ": " << it->second;
+  }
+  buffer << "}";
+  return buffer.str();
+}
+
 std::string GetErrorMessage(bool add_title = false);
 std::string GetWarningMessage();
 void SetErrorManagerContext();
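A usage sketch for the MapToString helper added above; the option map contents are illustrative, not taken from the commit, and the template is repeated here so the snippet compiles standalone:

#include <iostream>
#include <map>
#include <sstream>
#include <string>

// Copy of the helper added in the hunk above.
template <typename Map, typename K = typename Map::key_type, typename V = typename Map::mapped_type>
std::string MapToString(const Map &value) {
  std::stringstream buffer;
  buffer << "{";
  for (auto it = value.begin(); it != value.end(); it++) {
    if (it != value.begin()) {
      buffer << ", ";
    }
    buffer << it->first << ": " << it->second;
  }
  buffer << "}";
  return buffer.str();
}

int main() {
  // Hypothetical GE option map, ordered lexicographically by std::map.
  const std::map<std::string, std::string> options{
      {"ge.exec.dynamicGraphExecuteMode", "dynamic_execute"},
      {"ge.exec.dataInputsShapeRange", "[-1,3,224,224],[32]"}};
  // Prints: {ge.exec.dataInputsShapeRange: [-1,3,224,224],[32], ge.exec.dynamicGraphExecuteMode: dynamic_execute}
  std::cout << MapToString(options) << std::endl;
  return 0;
}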
@@ -35,6 +35,7 @@
 #include "runtime/hardware/device_context_manager.h"
 #include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
 #include "plugin/device/ascend/optimizer/ge_optimization.h"
+#include "plugin/device/ascend/hal/common/ascend_utils.h"
 #include "runtime/config.h"
 #include "runtime/dev.h"
 #include "distributed/init.h"
@@ -43,6 +44,8 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 namespace {
+using mindspore::transform::OptionMap;
+
 constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG";
 constexpr char kGeDumpMode[3][7] = {"all", "input", "output"};
 
@@ -101,6 +104,38 @@ transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
   return res;
 }
 
+std::string ShapesToString(const ShapeArray &shapes) {
+  std::stringstream buffer;
+  for (size_t i = 0; i < shapes.size(); ++i) {
+    if (i != 0) {
+      buffer << ",";
+    }
+    buffer << "[";
+    const auto &shape = shapes[i];
+    for (size_t j = 0; j < shape.size(); ++j) {
+      if (j != 0) {
+        buffer << ",";
+      }
+      buffer << shape[j];
+    }
+    buffer << "]";
+  }
+  return buffer.str();
+}
+
+OptionMap GetComputeGraphOptions(const ShapeArray &input_shapes, bool is_dynamic_shape) {
+  OptionMap options{};
+  if (common::GetEnv("GE_TRAIN") == "1") {
+    (void)options.emplace("ge.exec.variable_acc", "1");
+  }
+  if (!is_dynamic_shape) {
+    return options;
+  }
+  (void)options.emplace("ge.exec.dynamicGraphExecuteMode", "dynamic_execute");
+  (void)options.emplace("ge.exec.dataInputsShapeRange", ShapesToString(input_shapes));
+  return options;
+}
+
 bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
   MS_EXCEPTION_IF_NULL(anf_graph);
   auto converter = transform::NewConverter(anf_graph);
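A worked example of what the two helpers above produce; the input shapes are assumed for illustration, and ShapeVector/ShapeArray are re-declared with std:: types so the sketch runs standalone:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using ShapeVector = std::vector<int64_t>;     // stand-in for mindspore's ShapeVector
using ShapeArray = std::vector<ShapeVector>;  // stand-in for mindspore's ShapeArray

// Same logic as the ShapesToString added in the hunk above.
std::string ShapesToString(const ShapeArray &shapes) {
  std::stringstream buffer;
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (i != 0) {
      buffer << ",";
    }
    buffer << "[";
    const auto &shape = shapes[i];
    for (size_t j = 0; j < shape.size(); ++j) {
      if (j != 0) {
        buffer << ",";
      }
      buffer << shape[j];
    }
    buffer << "]";
  }
  return buffer.str();
}

int main() {
  // Two hypothetical data inputs; -1 marks a dynamic dimension.
  const ShapeArray shapes{{-1, 3, 224, 224}, {32}};
  // Prints: [-1,3,224,224],[32] -- the value GetComputeGraphOptions passes
  // to ge.exec.dataInputsShapeRange when is_dynamic_shape is true.
  std::cout << ShapesToString(shapes) << std::endl;
  return 0;
}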
@@ -121,11 +156,9 @@ bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &
   std::string graph_name = anf_graph->ToString();
   std::string init_graph = "init_subgraph." + graph_name;
   std::string checkpoint_name = "save." + GetGraphName(anf_graph);
-  if (common::GetEnv("GE_TRAIN") == "1") {
-    (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), {{"ge.exec.variable_acc", "1"}});
-  } else {
-    (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter));
-  }
+  const auto options = GetComputeGraphOptions(converter->input_shapes(), converter->dynamic_shape_inputs());
+  MS_LOG(INFO) << "Set options of compute graph: " << graph_name << " to " << MapToString(options);
+  (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), options);
   (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
   (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));
 
@@ -488,7 +521,7 @@ void GeDeviceContext::InitGe(const std::shared_ptr<MsContext> &inst_context) {
   {
     // Release GIL before calling into (potentially long-running) C++ code
     mindspore::ScopedLongRunning long_running;
-    if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
+    if (::ge::GEInitialize(ge_options) != ::ge::GRAPH_SUCCESS) {
       MS_LOG(EXCEPTION) << "Initialize GE failed!";
     }
   }
@@ -670,7 +703,7 @@ bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context)
     std::string exName(abi::__cxa_current_exception_type()->name());
     MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
   }
-  if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
+  if (::ge::GEFinalize() != ::ge::GRAPH_SUCCESS) {
     MS_LOG(WARNING) << "Finalize GE failed!";
   }
   inst_context->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
@@ -20,6 +20,7 @@
 #include <algorithm>
 #include <queue>
 #include <stack>
+#include <unordered_set>
 #include "include/common/utils/utils.h"
 #include "mindspore/core/ops/core_ops.h"
 #include "utils/anf_utils.h"
@@ -392,14 +393,16 @@ DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
     return *this;
   }
 
-  // Processing input with MakeDatasetHandler
+  // Processing input with MakeDatasetHandler and check whether input is dynamic
   for (auto &it : anf_graph_->inputs()) {
     auto op_itor = op_cache_.find(it.get());  // converted node
     if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
-      string name = std::static_pointer_cast<Parameter>(it)->name();
+      const auto &param = std::static_pointer_cast<Parameter>(it);
+      string name = param->name();
       auto tensor_itor = tensors.find(name);  // in init value map
       if (tensor_itor == tensors.end()) {
-        DfGraphConvertor::MakeDatasetHandler(name, input_idx, it);
+        MakeDatasetHandler(name, input_idx, it);
+        AddGraphDynamicInput(param);
         input_idx++;
       }
     }
@@ -1941,6 +1944,19 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
   DfGraphConvertor::SetOpInput(adpt, cnode);
 }
 
+void DfGraphConvertor::AddGraphDynamicInput(const ParameterPtr &param) {
+  MS_EXCEPTION_IF_NULL(param);
+  const auto &base_shape = param->Shape();
+  MS_EXCEPTION_IF_NULL(base_shape);
+  const auto &shape = base_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(shape);
+  const auto &sv = shape->shape();
+  if (std::any_of(sv.cbegin(), sv.cend(), [](const auto e) { return e == -1; })) {
+    dynamic_shape_inputs_ = true;
+  }
+  (void)input_shapes_.emplace_back(sv);
+}
+
 std::string DfGraphConvertor::GetGNodeName(const ::ge::GNode &node) const {
   ::ge::AscendString name;
   auto ret = node.GetName(name);
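A self-contained sketch of the dynamic-input test used in AddGraphDynamicInput above, with std::vector standing in for mindspore's ShapeVector; any -1 dimension flags the parameter as dynamic and every shape is recorded for the shape-range option:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;  // stand-in for mindspore's ShapeVector

int main() {
  // Hypothetical parameter shape with a dynamic batch dimension.
  const ShapeVector sv{-1, 3, 224, 224};
  // Same predicate as the commit: true as soon as one dimension is -1.
  const bool dynamic = std::any_of(sv.cbegin(), sv.cend(), [](int64_t e) { return e == -1; });
  std::cout << std::boolalpha << dynamic << std::endl;  // true
  return 0;
}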
@@ -2224,7 +2240,7 @@ OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) {
   return nullptr;
 }
 
-void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
+void DfGraphConvertor::ConvertTopK(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
   auto value_ptr = node->input(2)->cast<ValueNodePtr>();
@@ -2372,7 +2388,7 @@ std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) const {
   return cur_value;
 }
 
-void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
+void DfGraphConvertor::ConvertReshape(const CNodePtr &node) {
   MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
   const auto kInputNum = 3;
   if (node->size() < kInputNum) {
@@ -2386,10 +2402,7 @@ void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
   auto op = adpt->generate(node);
   MS_EXCEPTION_IF_NULL(op);
-  // get shape form attr
-  auto value_node = node->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node);
-  MS_EXCEPTION_IF_NULL(value_node->value());
-  auto primitive = value_node->value()->cast<PrimitivePtr>();
+  auto primitive = GetCNodePrimitive(node);
   MS_EXCEPTION_IF_NULL(primitive);
   auto value = primitive->GetAttr("shape");
   std::vector<int64_t> list;
@@ -2399,7 +2412,34 @@ void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
   op_cache_[node.get()] = op;
 }
 
-void DfGraphConvertor::ConvertAllReduce(const CNodePtr node) {
+void DfGraphConvertor::ConvertDynamicStitch(const CNodePtr &node) {
+  MS_LOG(INFO) << "Convert and set 'N' attr of DynamicStitch.";
+  OpAdapterPtr adpt = FindAdapter(node, training_);
+  if (adpt == nullptr) {
+    return;
+  }
+  auto op = adpt->generate(node);
+  MS_EXCEPTION_IF_NULL(op);
+  int64_t input_length = 0;
+  auto indices = node->input(1);
+  MS_EXCEPTION_IF_NULL(indices);
+  if (indices->isa<CNode>()) {
+    input_length = SizeToLong(indices->cast<CNodePtr>()->size()) - 1;
+  } else if (IsValueNode<ValueSequence>(indices)) {
+    const auto tuple = GetValueNode<ValueSequencePtr>(indices);
+    MS_EXCEPTION_IF_NULL(tuple);
+    input_length = SizeToLong(tuple->size());
+  } else {
+    MS_LOG(EXCEPTION) << "Input 1 of DynamicStitch is neither CNode nor ValueNode contains ValueSequence, but "
+                      << indices->ToString() << ", can not set 'N' attr.";
+  }
+
+  (void)op->SetAttr("N", input_length);
+  MS_LOG(INFO) << "Set 'N' attr of DynamicStitch to " << input_length;
+  op_cache_[node.get()] = op;
+}
+
+void DfGraphConvertor::ConvertAllReduce(const CNodePtr &node) {
   MS_LOG(INFO) << "Add AllReduce fusion_id";
   OpAdapterPtr adpt = FindAdapter(node, training_);
   if (adpt == nullptr) {
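A small worked example of the 'N' derivation above (operand names hypothetical): for a CNode, input 0 is the primitive itself, so N is size() - 1; for a ValueSequence, N is simply the element count:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // indices as a CNode: {primitive, i0, i1, i2} -> size() == 4, so N = 4 - 1 = 3.
  const std::vector<const char *> cnode_inputs{"prim::MakeTuple", "i0", "i1", "i2"};
  const int64_t n_from_cnode = static_cast<int64_t>(cnode_inputs.size()) - 1;
  // indices as a ValueSequence (i0, i1, i2): tuple->size() == 3, so N = 3.
  const int64_t n_from_tuple = 3;
  std::cout << n_from_cnode << " " << n_from_tuple << std::endl;  // 3 3
  return 0;
}

Either way, GE receives SetAttr("N", 3) for this hypothetical three-partition stitch.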
@@ -2408,10 +2448,7 @@ void DfGraphConvertor::ConvertAllReduce(const CNodePtr node) {
   auto op = adpt->generate(node);
   MS_EXCEPTION_IF_NULL(op);
-  // get shape form attr
-  auto value_node = node->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node);
-  MS_EXCEPTION_IF_NULL(value_node->value());
-  auto primitive = value_node->value()->cast<PrimitivePtr>();
+  auto primitive = GetCNodePrimitive(node);
   MS_EXCEPTION_IF_NULL(primitive);
   auto fusion_value = primitive->GetAttr("fusion");
   auto fusion = GetValue<int64_t>(fusion_value);
|
|||
op_cache_[node.get()] = op;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::ConvertConv2D(const CNodePtr node) {
|
||||
void DfGraphConvertor::ConvertConv2D(const CNodePtr &node) {
|
||||
MS_LOG(INFO) << "Convert and set 'padding' attr for Conv2D-like op.";
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
OpAdapterPtr adpt = FindAdapter(node, training_);
|
||||
if (adpt == nullptr) {
|
||||
|
@@ -2433,20 +2471,33 @@ void DfGraphConvertor::ConvertConv2D(const CNodePtr node) {
   }
   auto op = adpt->generate(node);
   MS_EXCEPTION_IF_NULL(op);
-  auto value_node = node->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node);
-  MS_EXCEPTION_IF_NULL(value_node->value());
-  auto primitive = value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto value = primitive->GetAttr("padding");
-  if (value != nullptr) {
-    std::string pad_mode = GetValue<std::string>(value);
-    (void)op->SetAttr("padding", pad_mode);
-  }
-  op_cache_[node.get()] = op;
+  auto primitive = GetCNodePrimitive(node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  std::string pad_mode;
+  if (auto value = primitive->GetAttr("padding"); value != nullptr) {
+    pad_mode = GetValue<std::string>(value);
+  } else if (auto value = primitive->GetAttr("pad_mode"); value != nullptr) {
+    // Get 'pad_mode' attr and set it to 'padding' attr for ge
+    const mindspore::HashMap<int64_t, std::string> pad_mode_map{{1, "SAME"}, {2, "VALID"}};
+    if (value->isa<StringImm>()) {
+      pad_mode = GetValue<std::string>(value);
+      (void)std::transform(pad_mode.cbegin(), pad_mode.cend(), pad_mode.begin(), toupper);
+    } else if (auto it = pad_mode_map.find(GetValue<int64_t>(value)); it != pad_mode_map.cend()) {
+      // 'pad_mode' attr could be an enumeration
+      pad_mode = it->second;
+    } else {
+      return;
+    }
+  } else {
+    MS_LOG(INFO) << "Node: " << node->fullname_with_scope() << " has no 'padding' or 'pad_mode' attr";
+    return;
+  }
+  MS_LOG(INFO) << "Set 'padding' attr of node: " << node->fullname_with_scope() << " to " << pad_mode;
+  (void)op->SetAttr("padding", pad_mode);
 }
 
-void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr node) {
+void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr &node) {
   MS_LOG(INFO) << "Add OCRRecognitionPreHandle _op_max_shape attr";
   OpAdapterPtr adpt = FindAdapter(node, training_);
   if (adpt == nullptr) {
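A minimal sketch of the pad-mode normalization above, using plain std types in place of mindspore's HashMap; NormalizePadMode is an illustrative helper, not part of the commit:

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Uppercases a string-valued pad_mode ("same" -> "SAME"), as the new
// ConvertConv2D does with std::transform.
std::string NormalizePadMode(const std::string &pad_mode_attr) {
  std::string pad_mode = pad_mode_attr;
  (void)std::transform(pad_mode.cbegin(), pad_mode.cend(), pad_mode.begin(),
                       [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
  return pad_mode;
}

int main() {
  // Enum form of pad_mode, mirroring the commit's pad_mode_map.
  const std::unordered_map<int64_t, std::string> pad_mode_map{{1, "SAME"}, {2, "VALID"}};
  std::cout << NormalizePadMode("same") << std::endl;  // SAME
  std::cout << pad_mode_map.at(2) << std::endl;        // VALID
  return 0;
}

Resolution order in the commit: an explicit 'padding' attr wins; otherwise 'pad_mode' is taken as a string (uppercased) or an enum (mapped via the table); if neither attr exists the op is left unchanged and only an INFO message is logged.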
@@ -2455,10 +2506,7 @@ void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr node) {
   auto op = adpt->generate(node);
   MS_EXCEPTION_IF_NULL(op);
-  // get shape form attr
-  auto value_node = node->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node);
-  MS_EXCEPTION_IF_NULL(value_node->value());
-  auto primitive = value_node->value()->cast<PrimitivePtr>();
+  auto primitive = GetCNodePrimitive(node);
   MS_EXCEPTION_IF_NULL(primitive);
   auto value = primitive->GetAttr("_op_max_shape");
   if (value == nullptr) {
@@ -2517,33 +2565,26 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
     return false;
   }
 
-  // Convert TopK second input from int64 to int32.
-  if (name == prim::kPrimTopK->name()) {
-    ConvertTopK(node);
-    return true;
-  }
-
-  // Convert Reshape add const input to attr(shape)
-  if (name == prim::kPrimReshape->name()) {
-    ConvertReshape(node);
-    return true;
-  }
-
-  if (name == prim::kPrimOCRRecognitionPreHandle->name()) {
-    ConvertOCRRecPreHandle(node);
-    return true;
-  }
-
-  // Add attr pad mode to Conv2D
-  if (name == prim::kPrimConv2D->name() || name == prim::kPrimDepthwiseConv2dNative->name() ||
-      name == kNameConv2DBackpropInputV2) {
-    ConvertConv2D(node);
-    return true;
-  }
-
-  if (name == prim::kPrimAllReduce->name()) {
-    ConvertAllReduce(node);
-    return true;
-  }
+  const mindspore::HashMap<std::string, std::function<void(decltype(this), const CNodePtr &)>>
+    auxiliary_node_converters{
+      // Convert TopK second input from int64 to int32.
+      {prim::kPrimTopK->name(), &DfGraphConvertor::ConvertTopK},
+      // Convert Reshape add const input to attr(shape)
+      {prim::kPrimReshape->name(), &DfGraphConvertor::ConvertReshape},
+      {prim::kPrimOCRRecognitionPreHandle->name(), &DfGraphConvertor::ConvertOCRRecPreHandle},
+      // Add attr 'pad_mode' to Conv2D-like op
+      {prim::kPrimConv2D->name(), &DfGraphConvertor::ConvertConv2D},
+      {prim::kPrimDepthwiseConv2dNative->name(), &DfGraphConvertor::ConvertConv2D},
+      {kNameConv2DBackpropInputV2, &DfGraphConvertor::ConvertConv2D},
+      {prim::kPrimConv2DBackpropInput->name(), &DfGraphConvertor::ConvertConv2D},
+      {prim::kPrimConv2DBackpropFilter->name(), &DfGraphConvertor::ConvertConv2D},
+      // Add attr 'N' to DynamicStitch
+      {prim::kPrimDynamicStitch->name(), &DfGraphConvertor::ConvertDynamicStitch},
+      {prim::kPrimAllReduce->name(), &DfGraphConvertor::ConvertAllReduce},
+    };
+
+  if (const auto it = auxiliary_node_converters.find(name); it != auxiliary_node_converters.cend()) {
+    it->second(this, node);
+  }
 
   return true;
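A standalone sketch of the dispatch pattern introduced above: op names map to member-function pointers wrapped in std::function and are invoked through find(). The Convertor struct and op names here are stand-ins, not the real DfGraphConvertor:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct Convertor {
  void ConvertTopK(int node) { std::cout << "TopK " << node << std::endl; }
  void ConvertReshape(int node) { std::cout << "Reshape " << node << std::endl; }
};

int main() {
  // std::function<void(Convertor *, int)> accepts a pointer-to-member-function;
  // invocation passes the object pointer as the implicit first argument.
  const std::unordered_map<std::string, std::function<void(Convertor *, int)>> converters{
      {"TopK", &Convertor::ConvertTopK},
      {"Reshape", &Convertor::ConvertReshape},
  };
  Convertor c;
  if (const auto it = converters.find("TopK"); it != converters.cend()) {
    it->second(&c, 42);  // dispatches to Convertor::ConvertTopK
  }
  return 0;
}

Compared with the replaced if-chain, the table keeps one lookup per node and makes registering a new auxiliary converter a one-line change.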
@@ -156,6 +156,8 @@ class DfGraphConvertor {
 
   bool is_training() const { return training_; }
   void set_training(bool is_training) { training_ = is_training; }
+  bool dynamic_shape_inputs() const { return dynamic_shape_inputs_; }
+  std::vector<ShapeVector> input_shapes() { return input_shapes_; }
 
  protected:
   void InitLoopVar(std::vector<::ge::Operator> *init_input);
@@ -178,15 +180,16 @@ class DfGraphConvertor {
   OperatorPtr ConvertValueNode(ValueNodePtr node);
   void SaveParamFormat(CNodePtr node);
   void GetBranchNodeInput(const CNodePtr node);
-  void ConvertTopK(const CNodePtr node);
+  void ConvertTopK(const CNodePtr &node);
   void ConvertResizeBilinear(const FuncGraphPtr anf_graph) const;
   void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) const;
   void ConvertTile(const FuncGraphPtr anf_graph) const;
   AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const;
-  void ConvertReshape(const CNodePtr node);
-  void ConvertAllReduce(const CNodePtr node);
-  void ConvertOCRRecPreHandle(const CNodePtr node);
-  void ConvertConv2D(const CNodePtr node);
+  void ConvertReshape(const CNodePtr &node);
+  void ConvertAllReduce(const CNodePtr &node);
+  void ConvertOCRRecPreHandle(const CNodePtr &node);
+  void ConvertConv2D(const CNodePtr &node);
+  void ConvertDynamicStitch(const CNodePtr &node);
   std::vector<int64_t> CastToInt(const ValuePtr &value) const;
   bool CheckCNode(const std::string &name, const CNodePtr node);
   void SetNodeInput(AnfNodePtr node);
@@ -198,6 +201,7 @@ class DfGraphConvertor {
   void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
   void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
   void AddGraphConstInput(const OperatorPtr &op);
+  void AddGraphDynamicInput(const ParameterPtr &param);
   AnfNodePtr ParseLoadInput(const CNodePtr &cnode) const;
   void SetGraphInputs(std::vector<Operator> *inputs);
   void TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred);
@@ -264,6 +268,7 @@ class DfGraphConvertor {
   std::vector<OperatorPtr> init_ops_;
   std::vector<OperatorPtr> broadcast_ops_;
   std::vector<AnfNodePtr> inputs_;
+  ShapeArray input_shapes_;
   OperatorPtr dataset_iter_getnext_;
   OperatorPtr queue_data_;
   OperatorPtr get_next_from_queue_;
@@ -271,6 +276,7 @@ class DfGraphConvertor {
   bool training_ = false;
   bool distribute_ = false;
   bool use_inputs_ = false;
+  bool dynamic_shape_inputs_ = false;
 
   AnfNodePtr while_cond_node_ = nullptr;
   mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_;
@@ -396,6 +396,9 @@ constexpr const char kNameTransData[] = "TransData";
 constexpr const char kNameSend[] = "Send";
 constexpr const char kNameReceive[] = "Receive";
 constexpr const char kNameIndexAdd[] = "IndexAdd";
+constexpr const char kNameUnique[] = "Unique";
+constexpr const char kNameDynamicBroadcastGradientArgs[] = "DynamicBroadcastGradientArgs";
+constexpr const char kNameDynamicStitch[] = "DynamicStitch";
 
 class OpAdapterDesc;
 
@@ -160,5 +160,15 @@ REG_ADPT_DESC(IdentityNMakeTuple, kNameMakeTuple, ADPT_DESC(IdentityN))
 REG_ADPT_DESC(IdentityNDepend, kNameDepend, ADPT_DESC(IdentityN))
 REG_ADPT_DESC(IdentityNReturn, kNameReturn, ADPT_DESC(IdentityN))
 
-// IdentityN
+// Unique
+INPUT_MAP(Unique) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(Unique) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(Unique) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(idx)}};
+REG_ADPT_DESC(Unique, kNameUnique, ADPT_DESC(Unique))
+
+// BroadcastGradientArgs
+INPUT_MAP(BroadcastGradientArgs) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
+ATTR_MAP(BroadcastGradientArgs) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(BroadcastGradientArgs) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}};
+REG_ADPT_DESC(BroadcastGradientArgs, kNameDynamicBroadcastGradientArgs, ADPT_DESC(BroadcastGradientArgs))
 }  // namespace mindspore::transform
@@ -84,5 +84,11 @@ DECLARE_OP_USE_OUTPUT(Identity)
 
 DECLARE_OP_ADAPTER(IdentityN)
 DECLARE_OP_USE_DYN_OUTPUT(IdentityN)
+
+DECLARE_OP_ADAPTER(Unique)
+DECLARE_OP_USE_OUTPUT(Unique)
+
+DECLARE_OP_ADAPTER(BroadcastGradientArgs)
+DECLARE_OP_USE_OUTPUT(BroadcastGradientArgs)
 }  // namespace mindspore::transform
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_
@@ -40,4 +40,11 @@ ATTR_MAP(TensorArrayGather) = {{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
                                {"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())}};
 OUTPUT_MAP(TensorArrayGather) = {{0, OUTPUT_DESC(value)}};
 REG_ADPT_DESC(TensorArrayGather, kNameTensorArrayGather, ADPT_DESC(TensorArrayGather))
+
+// DynamicStitch
+INPUT_MAP(DynamicStitch) = EMPTY_INPUT_MAP;
+DYN_INPUT_MAP(DynamicStitch) = {{1, DYN_INPUT_DESC(indices)}, {2, DYN_INPUT_DESC(x)}};
+ATTR_MAP(DynamicStitch) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(DynamicStitch) = {{0, OUTPUT_DESC(y)}};
+REG_ADPT_DESC(DynamicStitch, kNameDynamicStitch, ADPT_DESC(DynamicStitch))
 }  // namespace mindspore::transform
@@ -29,5 +29,8 @@ DECLARE_OP_USE_OUTPUT(TensorArrayWrite)
 
 DECLARE_OP_ADAPTER(TensorArrayGather)
 DECLARE_OP_USE_OUTPUT(TensorArrayGather)
+
+DECLARE_OP_ADAPTER(DynamicStitch)
+DECLARE_OP_USE_OUTPUT(DynamicStitch)
 }  // namespace mindspore::transform
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_
@@ -14,5 +14,6 @@
 # ============================================================================
 import os
 
 
 os.environ['MS_ENABLE_GE'] = '1'
+os.environ['MS_GE_TRAIN'] = '0'
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import ge_train_env  # pylint: disable=unused-import
+from tests.st.ge import ge_train_env  # pylint: disable=unused-import
 import mindspore as ms
 import mindspore.context as context
 import mindspore.nn as nn
@@ -15,7 +15,7 @@
 """ test ge frontend pass `AvgPoolGradForGE` """
 import numpy as np
 
-import ge_infer_env  # pylint: disable=unused-import
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore.ops import operations as op
@@ -14,7 +14,7 @@
 # ============================================================================
 import numpy as np
 
-import ge_infer_env  # pylint: disable=unused-import
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
 import mindspore
 from mindspore import context, nn, Tensor, ops
 
@@ -14,7 +14,7 @@
 # ============================================================================
 import numpy as np
 
-import ge_infer_env  # pylint: disable=unused-import
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
 import mindspore
 from mindspore import context, nn, Tensor, ops
 
@@ -15,7 +15,7 @@
 """ test ge frontend pass `DropoutForGE` `DropoutGradForGE` """
 import numpy as np
 
-import ge_infer_env  # pylint: disable=unused-import
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
 from mindspore import ops, nn, context, Tensor
 from mindspore.ops.composite import GradOperation
 
@@ -14,7 +14,7 @@
 # ============================================================================
 import numpy as np
 
-import ge_infer_env  # pylint: disable=unused-import
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
 import mindspore
 from mindspore import context, nn, Tensor, ops
 
@@ -15,7 +15,7 @@
 """ test ge frontend pass and op `TensorArray`"""
 import numpy as np
 
-import ge_infer_env  # pylint: disable=unused-import
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
 import mindspore.context as context
 from mindspore import nn
 from mindspore import Tensor
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import pytest
-import ge_test_utils as utils
+import tests.st.ge.ge_test_utils as utils
 
 
 @pytest.mark.level0
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import pytest
-import ge_test_utils as utils
+import tests.st.ge.ge_test_utils as utils
 
 
 @pytest.mark.level1
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+import inspect
 import os
 import sys
 
 
 def run_testcase(file_name, case_name=""):
+    caller_working_dir = os.path.dirname(inspect.stack()[1][1])
+    os.chdir(caller_working_dir)
     log_file = file_name + "_" + case_name + '.log'
     if case_name == "":
         ret = os.system(f'{sys.executable} {file_name}.py &> {log_file}')
@@ -14,5 +14,6 @@
 # ============================================================================
 import os
 
 
 os.environ['MS_ENABLE_GE'] = '1'
+os.environ['MS_GE_TRAIN'] = '1'
@@ -0,0 +1,53 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
+import numpy as np
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore.ops.operations import _inner_ops
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.args = _inner_ops.DynamicBroadcastGradientArgs()
+
+    def construct(self, s0, s1):
+        r = self.args(s0, s1)
+        return (r[0], r[1])
+
+
+def test_broadcast_gradient_args():
+    """
+    Feature: for DynamicBroadcastGradientArgs op
+    Description: inputs are two shapes
+    Expectation: the result is correct
+    """
+    shape0 = (4, 2, 1)
+    shape1 = (2, 7)
+    net = Net()
+    r0, r1 = net(shape0, shape1)
+    r0_expected = [2]
+    r1_expected = [0]
+
+    assert np.array_equal(r0_expected, r0.asnumpy())
+    assert np.array_equal(r1_expected, r1.asnumpy())
+
+
+if __name__ == "__main__":
+    test_broadcast_gradient_args()
@@ -0,0 +1,81 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _grad_ops as G
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        out_channel = 4
+        kernel_size = 1
+        self.conv_filter = G.Conv2DBackpropFilter(out_channel,
+                                                  kernel_size,
+                                                  pad_mode="valid",
+                                                  pad=0,
+                                                  mode=1,
+                                                  stride=1,
+                                                  dilation=1,
+                                                  group=1)
+        self.w = Parameter(
+            initializer(Tensor(np.array(
+                [[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
+            name='w')
+        self.x = Parameter(initializer(Tensor(np.array([[[
+            [3, 0, 1, 2, 7, 4],
+            [1, 5, 8, 9, 3, 1],
+            [2, 7, 2, 5, 1, 3],
+            [0, 1, 3, 1, 7, 8],
+            [4, 2, 1, 6, 2, 8],
+            [2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
+        self.out = Parameter(initializer(Tensor(np.array([[[
+            [-5, -4, 0, 8],
+            [-10, -2, 2, 3],
+            [0, -2, -4, -7],
+            [-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
+        self.get_shape = P.Shape()
+
+    @ms_function
+    def construct(self):
+        return self.conv_filter(self.out, self.x, self.get_shape(self.w))
+
+
+def test_conv2d_backprop_filter():
+    """
+    Feature: for Conv2DBackpropFilter op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    conv2d_filter = Net()
+    output = conv2d_filter()
+    expect = np.array([[[[-60, -142, -265],
+                         [-104, -211, -322],
+                         [-102, -144, -248]]]]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+
+
+if __name__ == "__main__":
+    test_conv2d_backprop_filter()
@@ -0,0 +1,83 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        out_channel = 4
+        kernel_size = 1
+        self.conv_input = P.Conv2DBackpropInput(out_channel,
+                                                kernel_size,
+                                                pad_mode="valid",
+                                                pad=0,
+                                                mode=1,
+                                                stride=1,
+                                                dilation=1,
+                                                group=1)
+        self.w = Parameter(
+            initializer(Tensor(np.array(
+                [[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
+            name='w')
+        self.x = Parameter(initializer(Tensor(np.array([[[
+            [3, 0, 1, 2, 7, 4],
+            [1, 5, 8, 9, 3, 1],
+            [2, 7, 2, 5, 1, 3],
+            [0, 1, 3, 1, 7, 8],
+            [4, 2, 1, 6, 2, 8],
+            [2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
+        self.out = Parameter(initializer(Tensor(np.array([[[
+            [-5, -4, 0, 8],
+            [-10, -2, 2, 3],
+            [0, -2, -4, -7],
+            [-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
+        self.get_shape = P.Shape()
+
+    @ms_function
+    def construct(self):
+        return self.conv_input(self.out, self.w, self.get_shape(self.x))
+
+
+def test_conv2d_backprop_input():
+    """
+    Feature: for Conv2DBackpropInput op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    conv2d_input = Net()
+    output = conv2d_input()
+    expect = np.array([[[[-5, -4, 5, 12, 0, -8],
+                         [-15, -6, 17, 17, -2, -11],
+                         [-15, -8, 13, 12, 2, -4],
+                         [-13, -6, 8, -14, 5, 20],
+                         [-3, -4, -4, -19, 7, 23],
+                         [-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+
+
+if __name__ == "__main__":
+    test_conv2d_backprop_input()
@@ -0,0 +1,83 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        out_channel = 4
+        kernel_size = 1
+        self.conv_input = P.Conv2DTranspose(out_channel,
+                                            kernel_size,
+                                            pad_mode="valid",
+                                            pad=0,
+                                            mode=1,
+                                            stride=1,
+                                            dilation=1,
+                                            group=1)
+        self.w = Parameter(
+            initializer(Tensor(np.array(
+                [[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
+            name='w')
+        self.x = Parameter(initializer(Tensor(np.array([[[
+            [3, 0, 1, 2, 7, 4],
+            [1, 5, 8, 9, 3, 1],
+            [2, 7, 2, 5, 1, 3],
+            [0, 1, 3, 1, 7, 8],
+            [4, 2, 1, 6, 2, 8],
+            [2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
+        self.out = Parameter(initializer(Tensor(np.array([[[
+            [-5, -4, 0, 8],
+            [-10, -2, 2, 3],
+            [0, -2, -4, -7],
+            [-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
+        self.get_shape = P.Shape()
+
+    @ms_function
+    def construct(self):
+        return self.conv_input(self.out, self.w, self.get_shape(self.x))
+
+
+def test_conv2d_transpose():
+    """
+    Feature: for Conv2DTranspose op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    conv2d_input = Net()
+    output = conv2d_input()
+    expect = np.array([[[[-5, -4, 5, 12, 0, -8],
+                         [-15, -6, 17, 17, -2, -11],
+                         [-15, -8, 13, 12, 2, -4],
+                         [-13, -6, 8, -14, 5, 20],
+                         [-3, -4, -4, -19, 7, 23],
+                         [-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+
+
+if __name__ == "__main__":
+    test_conv2d_transpose()
@@ -0,0 +1,54 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
+import numpy as np
+import mindspore.context as context
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+import mindspore.ops.operations as ops
+from mindspore import Tensor
+from mindspore.ops.operations import _inner_ops as inner
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.d_shape = ops.TensorShape()
+        self.d_broadcastto = inner.DynamicBroadcastTo()
+
+    def construct(self, data, shape):
+        shape = self.d_shape(shape)
+        return self.d_broadcastto(data, shape)
+
+
+def test_dynamic_broadcast_to():
+    """
+    Feature: for DynamicBroadcastTo op
+    Description: inputs are data and shape
+    Expectation: the result is correct
+    """
+    data = Tensor(np.array([1, 2, 3]), mstype.float32)
+    shape = Tensor(np.zeros((2, 3)), mstype.int64)
+    expect_data = np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32)
+    net = Net()
+    output = net(data, shape)
+    assert np.array_equal(output.asnumpy(), expect_data)
+
+
+if __name__ == "__main__":
+    test_dynamic_broadcast_to()
@@ -0,0 +1,94 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import tests.st.ge.ge_test_utils as utils
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_broadcast_gradient_args():
+    """
+    Feature: for DynamicBroadcastGradientArgs op
+    Description: inputs are two shapes
+    Expectation: the result is correct
+    """
+    utils.run_testcase('broadcast_gradient_args')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_conv2d_backprop_filter():
+    """
+    Feature: for Conv2DBackpropFilter op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    utils.run_testcase('conv2d_backprop_filter')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_conv2d_backprop_input():
+    """
+    Feature: for Conv2DBackpropInput op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    utils.run_testcase('conv2d_backprop_input')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_conv2d_transpose():
+    """
+    Feature: for Conv2DTranspose op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    utils.run_testcase('conv2d_transpose')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dynamic_broadcast_to():
+    """
+    Feature: for DynamicBroadcastTo op
+    Description: inputs are data and shape
+    Expectation: the result is correct
+    """
+    utils.run_testcase('dynamic_broadcast_to')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_unique():
+    """
+    Feature: for Unique op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    utils.run_testcase('unique')
@@ -0,0 +1,53 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from tests.st.ge import ge_infer_env  # pylint: disable=unused-import
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.unique = P.Unique()
+
+    def construct(self, x):
+        x = self.unique(x)
+        return (x[0], x[1])
+
+
+def test_unique():
+    """
+    Feature: for Unique op
+    Description: inputs are integers
+    Expectation: the result is correct
+    """
+    x = Tensor(np.array([1, 1, 2, 3, 3, 3]), mstype.int32)
+    unique = Net()
+    output = unique(x)
+    expect1 = np.array([1, 2, 3])
+    expect2 = np.array([0, 0, 1, 2, 2, 2])
+    assert (output[0].asnumpy() == expect1).all()
+    assert (output[1].asnumpy() == expect2).all()
+
+
+if __name__ == "__main__":
+    test_unique()