forked from mindspore-Ecosystem/mindspore
refine code* refine code in bert model* add ToAbstruct for `FuncGraph`, `MetaFuncGraph` `Primitive`* remove partial hard code in spec for poly* remove any in data convert cache
This commit is contained in:
parent
21a5f06e93
commit
484d7f10c8
|
@ -42,8 +42,8 @@ class GetRefParamEliminater : public OptimizerCaller {
|
||||||
public:
|
public:
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
PatternNode<AnfNodePtr> x;
|
PatternNode<AnfNodePtr> x;
|
||||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
|
||||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -128,7 +128,8 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
|
||||||
std::vector<std::pair<std::string, ValuePtr>> key_values;
|
std::vector<std::pair<std::string, ValuePtr>> key_values;
|
||||||
for (auto item : dict_values) {
|
for (auto item : dict_values) {
|
||||||
if (!py::isinstance<py::str>(item.first)) {
|
if (!py::isinstance<py::str>(item.first)) {
|
||||||
MS_LOG(EXCEPTION) << "The key of dict is only support str.";
|
MS_LOG(ERROR) << "The key of dict is only support str.";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
std::string key = py::str(item.first);
|
std::string key = py::str(item.first);
|
||||||
ValuePtr out = nullptr;
|
ValuePtr out = nullptr;
|
||||||
|
@ -158,7 +159,7 @@ void ConvertDataClass(py::object obj, ValuePtr *const data) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
|
bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
|
||||||
MS_LOG(DEBUG) << "Converting primitive object";
|
MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
|
||||||
|
|
||||||
// need check the primitive is class type or instance
|
// need check the primitive is class type or instance
|
||||||
auto obj_type = data_converter::GetObjType(obj);
|
auto obj_type = data_converter::GetObjType(obj);
|
||||||
|
@ -184,6 +185,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature =
|
||||||
} else {
|
} else {
|
||||||
*data = primitive;
|
*data = primitive;
|
||||||
}
|
}
|
||||||
|
MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -389,12 +391,12 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
||||||
std::string obj_id = results[0] + python_mod_get_parse_method;
|
std::string obj_id = results[0] + python_mod_get_parse_method;
|
||||||
std::string obj_key = results[1];
|
std::string obj_key = results[1];
|
||||||
FuncGraphPtr func_graph = nullptr;
|
FuncGraphPtr func_graph = nullptr;
|
||||||
Any value = Any();
|
ValuePtr value = nullptr;
|
||||||
bool is_cache = data_converter::GetObjectValue(obj_id, &value);
|
bool is_cache = data_converter::GetObjectValue(obj_id, &value);
|
||||||
if (is_cache) {
|
if (is_cache) {
|
||||||
if (value.is<FuncGraphPtr>()) {
|
if (value && value->isa<FuncGraph>()) {
|
||||||
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
|
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
|
||||||
func_graph = value.cast<FuncGraphPtr>();
|
func_graph = value->cast<FuncGraphPtr>();
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -415,10 +417,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
namespace data_converter {
|
namespace data_converter {
|
||||||
static std::unordered_map<std::string, Any> object_map_ = std::unordered_map<std::string, Any>();
|
static std::unordered_map<std::string, ValuePtr> object_map_;
|
||||||
|
|
||||||
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_ =
|
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
|
||||||
std::unordered_map<std::string, std::vector<FuncGraphPtr>>();
|
|
||||||
|
|
||||||
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
|
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
|
||||||
object_graphs_map_[obj_key].push_back(data);
|
object_graphs_map_[obj_key].push_back(data);
|
||||||
|
@ -430,8 +431,8 @@ const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs()
|
||||||
return object_graphs_map_;
|
return object_graphs_map_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; }
|
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
|
||||||
bool GetObjectValue(const std::string &obj_key, Any *const data) {
|
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
|
||||||
if (object_map_.count(obj_key)) {
|
if (object_map_.count(obj_key)) {
|
||||||
*data = object_map_[obj_key];
|
*data = object_map_[obj_key];
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -32,8 +32,8 @@ namespace mindspore {
|
||||||
namespace parse {
|
namespace parse {
|
||||||
// data convert for parse
|
// data convert for parse
|
||||||
namespace data_converter {
|
namespace data_converter {
|
||||||
void CacheObjectValue(const std::string &obj_key, const Any &data);
|
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data);
|
||||||
bool GetObjectValue(const std::string &obj_key, Any *const data);
|
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data);
|
||||||
|
|
||||||
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);
|
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);
|
||||||
|
|
||||||
|
|
|
@ -82,6 +82,9 @@ std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecialize
|
||||||
if (iter != specializations_.end()) {
|
if (iter != specializations_.end()) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
if (context->func_graph()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Specialize inner error";
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -539,8 +542,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
||||||
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
|
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
|
||||||
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
|
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
|
||||||
if (status == kSpecializeFindUniqueArgvalPoly ||
|
if (status == kSpecializeFindUniqueArgvalPoly ||
|
||||||
(func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
|
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
|
||||||
func->abstract()->isa<PartialAbstractClosure>()))) {
|
|
||||||
auto wrapped_node = BuildSpecializedParameterNode(new_node);
|
auto wrapped_node = BuildSpecializedParameterNode(new_node);
|
||||||
new_inputs[0] = wrapped_node;
|
new_inputs[0] = wrapped_node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "ir/manager.h"
|
#include "ir/manager.h"
|
||||||
#include "utils/ordered_set.h"
|
#include "utils/ordered_set.h"
|
||||||
#include "utils/convert_utils_base.h"
|
#include "utils/convert_utils_base.h"
|
||||||
|
#include "abstract/abstract_function.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
/*
|
/*
|
||||||
|
@ -48,6 +49,11 @@ FuncGraph::FuncGraph()
|
||||||
debug_info_ = std::make_shared<GraphDebugInfo>();
|
debug_info_ = std::make_shared<GraphDebugInfo>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr FuncGraph::ToAbstract() {
|
||||||
|
auto temp_context = abstract::AnalysisContext::DummyContext();
|
||||||
|
return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr FuncGraph::output() const {
|
AnfNodePtr FuncGraph::output() const {
|
||||||
// If return value is set, return should have two inputs.
|
// If return value is set, return should have two inputs.
|
||||||
if (return_ != nullptr && return_->inputs().size() == 2) {
|
if (return_ != nullptr && return_->inputs().size() == 2) {
|
||||||
|
|
|
@ -149,6 +149,7 @@ class FuncGraph : public FuncGraphBase {
|
||||||
|
|
||||||
// get the graph's abstract
|
// get the graph's abstract
|
||||||
abstract::AbstractFunctionPtr abstract();
|
abstract::AbstractFunctionPtr abstract();
|
||||||
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
|
|
||||||
// return the graph's output, or nullptr if not yet deduced
|
// return the graph's output, or nullptr if not yet deduced
|
||||||
AnfNodePtr output() const;
|
AnfNodePtr output() const;
|
||||||
|
|
|
@ -19,9 +19,15 @@
|
||||||
#include "ir/meta_func_graph.h"
|
#include "ir/meta_func_graph.h"
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
#include "utils/context/ms_context.h"
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "abstract/abstract_function.h"
|
||||||
|
|
||||||
// namespace to support intermediate representation definition
|
// namespace to support intermediate representation definition
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
|
||||||
|
return std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>());
|
||||||
|
}
|
||||||
|
|
||||||
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
|
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
|
||||||
auto context = MsContext::GetInstance();
|
auto context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
|
|
|
@ -49,7 +49,7 @@ class MetaFuncGraph : public FuncGraphBase {
|
||||||
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
|
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
|
||||||
return args_spec_list;
|
return args_spec_list;
|
||||||
}
|
}
|
||||||
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
const std::vector<Signature> &signatures() const { return signatures_; }
|
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||||
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
|
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
|
||||||
// Generate a Graph for the given abstract arguments.
|
// Generate a Graph for the given abstract arguments.
|
||||||
|
|
|
@ -17,8 +17,15 @@
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include "abstract/abstract_function.h"
|
||||||
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr Primitive::ToAbstract() {
|
||||||
|
return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
bool Primitive::operator==(const Value &other) const {
|
bool Primitive::operator==(const Value &other) const {
|
||||||
if (other.isa<Primitive>()) {
|
if (other.isa<Primitive>()) {
|
||||||
auto other_prim = static_cast<const Primitive &>(other);
|
auto other_prim = static_cast<const Primitive &>(other);
|
||||||
|
|
|
@ -57,7 +57,7 @@ class Primitive : public Named {
|
||||||
record_evaluate_add_attr_(false) {}
|
record_evaluate_add_attr_(false) {}
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Primitive, Named);
|
MS_DECLARE_PARENT(Primitive, Named);
|
||||||
|
abstract::AbstractBasePtr ToAbstract();
|
||||||
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
||||||
std::string ToString() const override { return name(); }
|
std::string ToString() const override { return name(); }
|
||||||
void BeginRecordAddAttr() {
|
void BeginRecordAddAttr() {
|
||||||
|
|
|
@ -102,7 +102,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
||||||
def construct(self, data, label):
|
def construct(self, data, label):
|
||||||
out = self._backbone(data)
|
out = self._backbone(data)
|
||||||
label = F.mixed_precision_cast(mstype.float32, label)
|
label = F.mixed_precision_cast(mstype.float32, label)
|
||||||
return self._loss_fn(F.cast(out, mstype.float32), label)
|
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
|
||||||
|
|
||||||
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
|
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
|
||||||
if cast_model_type == mstype.float16:
|
if cast_model_type == mstype.float16:
|
||||||
|
|
|
@ -25,7 +25,8 @@ from mindspore import Tensor
|
||||||
from mindspore.nn.optim import AdamWeightDecay
|
from mindspore.nn.optim import AdamWeightDecay
|
||||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||||
from mindspore.nn import learning_rate_schedule as lr_schedules
|
from mindspore.nn import learning_rate_schedule as lr_schedules
|
||||||
from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
from mindspore.ops import operations as P
|
||||||
|
from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||||
from ...dataset_mock import MindData
|
from ...dataset_mock import MindData
|
||||||
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph
|
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph
|
||||||
|
|
||||||
|
@ -100,7 +101,7 @@ def get_config(version='base', batch_size=1):
|
||||||
|
|
||||||
|
|
||||||
class BertLearningRate(lr_schedules.LearningRateSchedule):
|
class BertLearningRate(lr_schedules.LearningRateSchedule):
|
||||||
def __init__(self, decay_steps, warmup_steps=0, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
|
def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
|
||||||
super(BertLearningRate, self).__init__()
|
super(BertLearningRate, self).__init__()
|
||||||
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
|
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
|
||||||
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
||||||
|
|
|
@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
|
||||||
def __init__(self, length, max_relative_position):
|
def __init__(self, length, max_relative_position):
|
||||||
super(RelaPosMatrixGenerator, self).__init__()
|
super(RelaPosMatrixGenerator, self).__init__()
|
||||||
self._length = length
|
self._length = length
|
||||||
self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
|
self._max_relative_position = max_relative_position
|
||||||
self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
|
self._min_relative_position = -max_relative_position
|
||||||
self.range_length = -length + 1
|
self.range_length = -length + 1
|
||||||
|
|
||||||
self.tile = P.Tile()
|
self.tile = P.Tile()
|
||||||
|
@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
||||||
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
|
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
|
||||||
max_relative_position=max_relative_position)
|
max_relative_position=max_relative_position)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.one_hot = P.OneHot()
|
self.one_hot = nn.OneHot(depth=self.vocab_size)
|
||||||
self.on_value = Tensor(1.0, mstype.float32)
|
|
||||||
self.off_value = Tensor(0.0, mstype.float32)
|
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.gather = P.GatherV2() # index_select
|
self.gather = P.GatherV2() # index_select
|
||||||
self.matmul = P.BatchMatMul()
|
self.matmul = P.BatchMatMul()
|
||||||
|
@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
||||||
if self.use_one_hot_embeddings:
|
if self.use_one_hot_embeddings:
|
||||||
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
|
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
|
||||||
one_hot_relative_positions_matrix = self.one_hot(
|
one_hot_relative_positions_matrix = self.one_hot(
|
||||||
flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
|
flat_relative_positions_matrix)
|
||||||
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
|
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
|
||||||
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
|
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
|
||||||
embeddings = self.reshape(embeddings, my_shape)
|
embeddings = self.reshape(embeddings, my_shape)
|
||||||
|
@ -372,11 +370,11 @@ class SaturateCast(nn.Cell):
|
||||||
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
|
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
|
||||||
super(SaturateCast, self).__init__()
|
super(SaturateCast, self).__init__()
|
||||||
np_type = mstype.dtype_to_nptype(dst_type)
|
np_type = mstype.dtype_to_nptype(dst_type)
|
||||||
min_type = np.finfo(np_type).min
|
min_type = float(np.finfo(np_type).min)
|
||||||
max_type = np.finfo(np_type).max
|
max_type = float(np.finfo(np_type).max)
|
||||||
|
|
||||||
self.tensor_min_type = Tensor([min_type], dtype=src_type)
|
self.tensor_min_type = min_type
|
||||||
self.tensor_max_type = Tensor([max_type], dtype=src_type)
|
self.tensor_max_type = max_type
|
||||||
|
|
||||||
self.min_op = P.Minimum()
|
self.min_op = P.Minimum()
|
||||||
self.max_op = P.Maximum()
|
self.max_op = P.Maximum()
|
||||||
|
@ -442,7 +440,7 @@ class BertAttention(nn.Cell):
|
||||||
self.has_attention_mask = has_attention_mask
|
self.has_attention_mask = has_attention_mask
|
||||||
self.use_relative_positions = use_relative_positions
|
self.use_relative_positions = use_relative_positions
|
||||||
|
|
||||||
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
|
self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.shape_from_2d = (-1, from_tensor_width)
|
self.shape_from_2d = (-1, from_tensor_width)
|
||||||
self.shape_to_2d = (-1, to_tensor_width)
|
self.shape_to_2d = (-1, to_tensor_width)
|
||||||
|
@ -471,7 +469,7 @@ class BertAttention(nn.Cell):
|
||||||
self.trans_shape = (0, 2, 1, 3)
|
self.trans_shape = (0, 2, 1, 3)
|
||||||
self.trans_shape_relative = (2, 0, 1, 3)
|
self.trans_shape_relative = (2, 0, 1, 3)
|
||||||
self.trans_shape_position = (1, 2, 0, 3)
|
self.trans_shape_position = (1, 2, 0, 3)
|
||||||
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
|
self.multiply_data = -10000.0
|
||||||
self.batch_num = batch_size * num_attention_heads
|
self.batch_num = batch_size * num_attention_heads
|
||||||
self.matmul = P.BatchMatMul()
|
self.matmul = P.BatchMatMul()
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
""" test nn ops """
|
""" test nn ops """
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.random import normal
|
from numpy.random import normal
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
|
@ -311,6 +312,7 @@ def test_op_with_arg_as_input():
|
||||||
|
|
||||||
# The partial application used as argument is not supported yet
|
# The partial application used as argument is not supported yet
|
||||||
# because of the limit of inference specialize system
|
# because of the limit of inference specialize system
|
||||||
|
@pytest.mark.skip("poly in infer")
|
||||||
def test_partial_as_arg():
|
def test_partial_as_arg():
|
||||||
class PartialArgNet(nn.Cell):
|
class PartialArgNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue