forked from mindspore-Ecosystem/mindspore
!683 WIP: specialize hyper map parameter
Merge pull request !683 from xychow/bypass-renorm-and-specialize-hypermap-parameter
This commit is contained in:
commit
da7d605e85
|
@ -42,6 +42,7 @@ using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher
|
|||
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
|
||||
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
|
||||
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
||||
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
||||
|
||||
// ANF transform class
|
||||
// either a primitive or a func_graph
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <sstream>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pipeline/static_analysis/abstract_function.h"
|
||||
#include "pipeline/static_analysis/dshape.h"
|
||||
|
@ -334,6 +335,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
|
|||
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
|
||||
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->debug_info()->set_name("hyper_map");
|
||||
|
||||
AnfNodePtr ptrFnArg = nullptr;
|
||||
|
|
|
@ -278,10 +278,12 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
|
|||
// Convert class to Tuple
|
||||
// Convert getattr to getitem
|
||||
// Convert make_record to make_tuple
|
||||
void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(root);
|
||||
|
||||
bool changed = false;
|
||||
|
||||
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||
AnfNodeSet all_node = manager->all_nodes();
|
||||
for (auto &node : all_node) {
|
||||
|
@ -316,7 +318,9 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
|
||||
if (new_node != nullptr) {
|
||||
new_node->set_abstract(node->abstract());
|
||||
MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
|
||||
(void)manager->Replace(node, new_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -324,6 +328,7 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
auto ret = Reabs(node->abstract());
|
||||
node->set_abstract(ret);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
// expand tuples in graph parameters
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
|
||||
// Remove the class type from graphs
|
||||
void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
|
||||
// Remove most uses of tuples from the graph
|
||||
// tuples that are returned will be kept
|
||||
|
|
|
@ -38,13 +38,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
|
|||
|
||||
// src type check
|
||||
auto src_type = src_->Type();
|
||||
if (src_type == nullptr) {
|
||||
if (src_type == nullptr || !src_type->isa<TensorType>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (src_type->isa<TensorType>()) {
|
||||
src_type = src_type->cast<TensorTypePtr>()->element();
|
||||
}
|
||||
src_type = src_type->cast<TensorTypePtr>()->element();
|
||||
|
||||
// tgt type check
|
||||
auto tgt_type = GetValueNode<TypePtr>(tgt_);
|
||||
|
|
|
@ -52,14 +52,16 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
opt::SimplifyDataStructures(func_graph, res->manager());
|
||||
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
|
||||
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
if (changed) {
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
}
|
||||
res->set_args_spec(args_spec);
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -177,8 +177,8 @@ std::size_t FuncGraphAbstractClosure::hash() const {
|
|||
|
||||
std::string FuncGraphAbstractClosure::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << "FuncGraphAbstractClosure: " << this << "FuncGraph: " << func_graph_.get() << ", " << func_graph_->ToString()
|
||||
<< "; Context: " << context_.get() << context_->ToString();
|
||||
ss << "FuncGraphAbstractClosure: "
|
||||
<< "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -166,8 +166,9 @@ class PartialAbstractClosure : public AbstractFuncAtom {
|
|||
public:
|
||||
// Represents a partial application.
|
||||
// args_spec_list: The first few arguments of that function
|
||||
PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list)
|
||||
: fn_(fn), args_spec_list_(args_spec_list) {}
|
||||
PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
|
||||
const AnfNodePtr &node = nullptr)
|
||||
: fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
|
||||
~PartialAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
|
@ -175,7 +176,11 @@ class PartialAbstractClosure : public AbstractFuncAtom {
|
|||
|
||||
AbstractFunctionPtr fn() { return fn_; }
|
||||
AbstractBasePtrList args() { return args_spec_list_; }
|
||||
AbstractFunctionPtr Copy() const override { return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_); }
|
||||
AnfNodePtr node() { return node_.lock(); }
|
||||
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
|
||||
AbstractFunctionPtr Copy() const override {
|
||||
return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock());
|
||||
}
|
||||
bool operator==(const AbstractFunction &other) const override;
|
||||
std::size_t hash() const override;
|
||||
|
||||
|
@ -184,6 +189,8 @@ class PartialAbstractClosure : public AbstractFuncAtom {
|
|||
private:
|
||||
AbstractFuncAtomPtr fn_;
|
||||
AbstractBasePtrList args_spec_list_;
|
||||
// The CNode which this PartialAbstractClosure evaluated from.
|
||||
AnfNodeWeakPtr node_;
|
||||
};
|
||||
|
||||
class JTransformedAbstractClosure : public AbstractFuncAtom {
|
||||
|
|
|
@ -951,8 +951,19 @@ class PartialEvaluator : public Evaluator {
|
|||
if (args_conf_list.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
|
||||
auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
|
||||
AbstractBasePtrList args_spec_list{arg0_value};
|
||||
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
|
||||
if (arg0_value->isa<AbstractError>()) {
|
||||
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
|
||||
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
|
||||
<< " as func is: " << arg0_value->ToString();
|
||||
(*cache_)[args_spec_list] = ret;
|
||||
return ret;
|
||||
}
|
||||
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
|
||||
// Sometimes, node[0] in out_conf becomes phi0;
|
||||
if (func->isa<PrimitiveAbstractClosure>()) {
|
||||
|
@ -962,19 +973,26 @@ class PartialEvaluator : public Evaluator {
|
|||
return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
|
||||
}
|
||||
}
|
||||
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
|
||||
|
||||
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); });
|
||||
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
|
||||
|
||||
AbstractFuncAtomPtrList partialPtrList;
|
||||
auto build_partial = [args, &partialPtrList](const AbstractFuncAtomPtr &atom_func) {
|
||||
auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args);
|
||||
partialPtrList.push_back(new_func);
|
||||
auto cnode = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() != (args_conf_list.size() + 1)) {
|
||||
MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
|
||||
<< ", args_conf_list: " << mindspore::ToString(args_conf_list);
|
||||
}
|
||||
|
||||
AbstractFuncAtomPtrList partial_funcs_list;
|
||||
auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
|
||||
auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
|
||||
partial_funcs_list.push_back(new_func);
|
||||
};
|
||||
func->Visit(build_partial);
|
||||
|
||||
auto ret = AbstractFunction::MakeAbstractFunction(partialPtrList);
|
||||
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
|
||||
(*cache_)[args_spec_list] = ret;
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -23,7 +23,9 @@
|
|||
#include "./common.h"
|
||||
#include "operator/ops.h"
|
||||
#include "operator/composite/do_signature.h"
|
||||
#include "pipeline/static_analysis/abstract_function.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/profile.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
|
@ -232,6 +234,13 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
return;
|
||||
}
|
||||
new_node->set_abstract(GetEvaluatedValueWrap(conf));
|
||||
if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
|
||||
auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
|
||||
if (partial_abstract->node() == node) {
|
||||
partial_abstract->set_node(new_node);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
|
||||
|
||||
if (node->isa<CNode>()) {
|
||||
|
@ -383,6 +392,56 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
|
|||
return BuildValueNode(v, abs);
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
|
||||
auto new_inputs = new_node->inputs();
|
||||
AnfNodePtr func = new_inputs[0];
|
||||
AbstractBasePtr fnval = new_inputs[0]->abstract();
|
||||
|
||||
AbstractBasePtrList args;
|
||||
auto backed_fnval = fnval;
|
||||
if (fnval->isa<PartialAbstractClosure>()) {
|
||||
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
|
||||
backed_fnval = partial_closure->fn();
|
||||
args = partial_closure->args();
|
||||
}
|
||||
std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
|
||||
[](const AnfNodePtr &inp) { return inp->abstract(); });
|
||||
|
||||
ScopeGuard scope_guard(new_node->scope());
|
||||
|
||||
auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
|
||||
auto wrapped_node = specialized_node;
|
||||
if (fnval->isa<PartialAbstractClosure>()) {
|
||||
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
|
||||
AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
|
||||
specialized_node};
|
||||
auto anf_node = partial_closure->node();
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if (cnode->size() != partial_closure->args().size() + 2) {
|
||||
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
|
||||
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
|
||||
}
|
||||
for (size_t i = 0; i < partial_closure->args().size(); i++) {
|
||||
auto old_node = cnode->input(i + 2);
|
||||
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
|
||||
if (possibile_value_node != nullptr) {
|
||||
partial_node_list.push_back(possibile_value_node);
|
||||
} else {
|
||||
if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
|
||||
MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
|
||||
}
|
||||
partial_node_list.push_back(old_node);
|
||||
}
|
||||
}
|
||||
wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
|
||||
wrapped_node->set_abstract(partial_closure);
|
||||
}
|
||||
return wrapped_node;
|
||||
}
|
||||
|
||||
const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
|
||||
auto cache_iter = evalcaches_.find(eval);
|
||||
if (cache_iter == evalcaches_.end()) {
|
||||
|
@ -465,6 +524,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
<< new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
|
||||
}
|
||||
|
||||
if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
|
||||
auto wrapped_node = BuildSpecializedParameterNode(new_node);
|
||||
new_inputs[0] = wrapped_node;
|
||||
}
|
||||
|
||||
if (CanSpecializeNode(func)) {
|
||||
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
|
||||
}
|
||||
|
@ -474,16 +538,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
if (CanSpecializeNode(args[i])) {
|
||||
new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
|
||||
}
|
||||
// support for partial(Multitype) which Multitype should not be inferred to POLY.
|
||||
// after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only,
|
||||
// so even with partial parameter, it will specialize to that graph.
|
||||
// Maybe a better idea should inline graph with partial node first, then it will have full
|
||||
// parameter list to infer and specialize.
|
||||
MS_EXCEPTION_IF_NULL(new_inputs[next]);
|
||||
if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) &&
|
||||
IsPrimitive(func, prim::kPrimPartial)) {
|
||||
new_inputs[next] = args[i];
|
||||
}
|
||||
i = next;
|
||||
}
|
||||
|
||||
|
|
|
@ -106,6 +106,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
// (disconnected).
|
||||
AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);
|
||||
|
||||
// Build a value node from parameter if the function graph has special flag to hint it can be done.
|
||||
AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
|
||||
|
||||
// Build a value node if ival is constant and not any-value
|
||||
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
|
||||
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
|
||||
|
|
|
@ -87,11 +87,6 @@ class CumSumNet(nn.Cell):
|
|||
|
||||
|
||||
raise_set = [
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('TensorAdd0', {
|
||||
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('TensorAdd1', {
|
||||
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
|
||||
|
@ -271,11 +266,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Sub0', {
|
||||
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Sub1', {
|
||||
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
|
||||
|
@ -287,11 +277,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Mul0', {
|
||||
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Mul1', {
|
||||
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
|
||||
|
@ -352,11 +337,6 @@ raise_set = [
|
|||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Minimum0', {
|
||||
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Minimum1', {
|
||||
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
|
||||
|
@ -368,11 +348,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Maximum0', {
|
||||
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Maximum1', {
|
||||
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
|
||||
|
@ -384,11 +359,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('RealDiv0', {
|
||||
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('RealDiv1', {
|
||||
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
|
||||
|
@ -400,11 +370,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Div0', {
|
||||
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Div1', {
|
||||
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
|
||||
|
@ -416,11 +381,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('FloorDiv0', {
|
||||
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('FloorDiv1', {
|
||||
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
|
||||
|
@ -439,11 +399,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('FloorMod0', {
|
||||
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('FloorMod1', {
|
||||
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
|
||||
|
@ -462,11 +417,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Equal0', {
|
||||
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Equal1', {
|
||||
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
|
||||
|
@ -490,11 +440,6 @@ raise_set = [
|
|||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
|
||||
# input is not tensor
|
||||
('NotEqual0', {
|
||||
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('NotEqual1', {
|
||||
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
|
||||
|
@ -506,11 +451,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Greater0', {
|
||||
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Greater1', {
|
||||
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
|
||||
|
@ -522,11 +462,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('GreaterEqual0', {
|
||||
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('GreaterEqual1', {
|
||||
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
|
||||
|
@ -538,11 +473,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Less0', {
|
||||
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Less1', {
|
||||
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
|
||||
|
@ -554,11 +484,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('LessEqual0', {
|
||||
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('LessEqual1', {
|
||||
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
|
||||
|
@ -728,11 +653,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Atan20', {
|
||||
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Atan21', {
|
||||
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_hypermap_partial """
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def test_hypermap_specialize_param():
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.mul = P.Mul()
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.mul(x, y)
|
||||
return ret
|
||||
|
||||
factor1 = Tensor(5, dtype=mstype.int32)
|
||||
x = Tensor(np.ones([1]).astype(np.int32))
|
||||
y = Tensor(np.ones([2]).astype(np.int32))
|
||||
net = Net()
|
||||
hypermap = C.HyperMap()
|
||||
|
||||
@ms_function
|
||||
def hypermap_specialize_param():
|
||||
ret1 = hypermap(F.partial(net, factor1), (x, y))
|
||||
# List will be converted to Tuple in SimlifyDataStructurePass.
|
||||
ret2 = hypermap(F.partial(net, factor1), [x, y])
|
||||
return ret1, ret2
|
||||
|
||||
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
|
||||
ret = hypermap_specialize_param()
|
||||
assert(ret == (expected_ret, expected_ret))
|
Loading…
Reference in New Issue