!3356 [common]Refine code & Add ToAbstruct for `FuncGraph`, `MetaFuncGraph` `Primitive`

Merge pull request !3356 from vlne-v1/add_to_abstract_for_function
This commit is contained in:
mindspore-ci-bot 2020-07-23 14:33:58 +08:00 committed by Gitee
commit 48711414fc
14 changed files with 57 additions and 33 deletions

View File

@ -42,8 +42,8 @@ class GetRefParamEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
return nullptr;
}
};

View File

@ -128,7 +128,8 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
std::vector<std::pair<std::string, ValuePtr>> key_values;
for (auto item : dict_values) {
if (!py::isinstance<py::str>(item.first)) {
MS_LOG(EXCEPTION) << "The key of dict is only support str.";
MS_LOG(ERROR) << "The key of dict is only support str.";
return false;
}
std::string key = py::str(item.first);
ValuePtr out = nullptr;
@ -158,7 +159,7 @@ void ConvertDataClass(py::object obj, ValuePtr *const data) {
}
bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
MS_LOG(DEBUG) << "Converting primitive object";
MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
// need check the primitive is class type or instance
auto obj_type = data_converter::GetObjType(obj);
@ -184,6 +185,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature =
} else {
*data = primitive;
}
MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
}
return true;
}
@ -389,12 +391,12 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
std::string obj_id = results[0] + python_mod_get_parse_method;
std::string obj_key = results[1];
FuncGraphPtr func_graph = nullptr;
Any value = Any();
ValuePtr value = nullptr;
bool is_cache = data_converter::GetObjectValue(obj_id, &value);
if (is_cache) {
if (value.is<FuncGraphPtr>()) {
if (value && value->isa<FuncGraph>()) {
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
func_graph = value.cast<FuncGraphPtr>();
func_graph = value->cast<FuncGraphPtr>();
return func_graph;
}
}
@ -415,10 +417,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
return func_graph;
}
namespace data_converter {
static std::unordered_map<std::string, Any> object_map_ = std::unordered_map<std::string, Any>();
static std::unordered_map<std::string, ValuePtr> object_map_;
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_ =
std::unordered_map<std::string, std::vector<FuncGraphPtr>>();
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
object_graphs_map_[obj_key].push_back(data);
@ -430,8 +431,8 @@ const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs()
return object_graphs_map_;
}
void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; }
bool GetObjectValue(const std::string &obj_key, Any *const data) {
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
if (object_map_.count(obj_key)) {
*data = object_map_[obj_key];
return true;

View File

@ -32,8 +32,8 @@ namespace mindspore {
namespace parse {
// data convert for parse
namespace data_converter {
void CacheObjectValue(const std::string &obj_key, const Any &data);
bool GetObjectValue(const std::string &obj_key, Any *const data);
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data);
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data);
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);

View File

@ -82,6 +82,9 @@ std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecialize
if (iter != specializations_.end()) {
return iter->second;
}
if (context->func_graph()) {
MS_LOG(EXCEPTION) << "Specialize inner error";
}
return nullptr;
}
@ -540,8 +543,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
if (status == kSpecializeFindUniqueArgvalPoly ||
(func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
func->abstract()->isa<PartialAbstractClosure>()))) {
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
new_inputs[0] = wrapped_node;
}

View File

@ -26,6 +26,7 @@
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/convert_utils_base.h"
#include "abstract/abstract_function.h"
namespace mindspore {
/*
@ -48,6 +49,11 @@ FuncGraph::FuncGraph()
debug_info_ = std::make_shared<GraphDebugInfo>();
}
abstract::AbstractBasePtr FuncGraph::ToAbstract() {
auto temp_context = abstract::AnalysisContext::DummyContext();
return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
}
AnfNodePtr FuncGraph::output() const {
// If return value is set, return should have two inputs.
if (return_ != nullptr && return_->inputs().size() == 2) {

View File

@ -149,6 +149,7 @@ class FuncGraph : public FuncGraphBase {
// get the graph's abstract
abstract::AbstractFunctionPtr abstract();
abstract::AbstractBasePtr ToAbstract() override;
// return the graph's output, or nullptr if not yet deduced
AnfNodePtr output() const;

View File

@ -19,9 +19,15 @@
#include "ir/meta_func_graph.h"
#include "base/core_ops.h"
#include "utils/context/ms_context.h"
#include "abstract/abstract_function.h"
// namespace to support intermediate representation definition
namespace mindspore {
abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
return std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>());
}
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);

View File

@ -49,7 +49,7 @@ class MetaFuncGraph : public FuncGraphBase {
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
return args_spec_list;
}
abstract::AbstractBasePtr ToAbstract() override;
const std::vector<Signature> &signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
// Generate a Graph for the given abstract arguments.

View File

@ -17,8 +17,15 @@
#include "ir/primitive.h"
#include <utility>
#include "abstract/abstract_function.h"
namespace mindspore {
abstract::AbstractBasePtr Primitive::ToAbstract() {
return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr);
}
bool Primitive::operator==(const Value &other) const {
if (other.isa<Primitive>()) {
auto other_prim = static_cast<const Primitive &>(other);

View File

@ -57,7 +57,7 @@ class Primitive : public Named {
record_evaluate_add_attr_(false) {}
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToAbstract();
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
void BeginRecordAddAttr() {

View File

@ -102,7 +102,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
def construct(self, data, label):
out = self._backbone(data)
label = F.mixed_precision_cast(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label)
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
if cast_model_type == mstype.float16:

View File

@ -25,7 +25,8 @@ from mindspore import Tensor
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn import learning_rate_schedule as lr_schedules
from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.ops import operations as P
from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from ...dataset_mock import MindData
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph
@ -100,7 +101,7 @@ def get_config(version='base', batch_size=1):
class BertLearningRate(lr_schedules.LearningRateSchedule):
def __init__(self, decay_steps, warmup_steps=0, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)

View File

@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
def __init__(self, length, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._length = length
self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
self._max_relative_position = max_relative_position
self._min_relative_position = -max_relative_position
self.range_length = -length + 1
self.tile = P.Tile()
@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
max_relative_position=max_relative_position)
self.reshape = P.Reshape()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.one_hot = nn.OneHot(depth=self.vocab_size)
self.shape = P.Shape()
self.gather = P.GatherV2() # index_select
self.matmul = P.BatchMatMul()
@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
flat_relative_positions_matrix)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
@ -372,11 +370,11 @@ class SaturateCast(nn.Cell):
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
min_type = np.finfo(np_type).min
max_type = np.finfo(np_type).max
min_type = float(np.finfo(np_type).min)
max_type = float(np.finfo(np_type).max)
self.tensor_min_type = Tensor([min_type], dtype=src_type)
self.tensor_max_type = Tensor([max_type], dtype=src_type)
self.tensor_min_type = min_type
self.tensor_max_type = max_type
self.min_op = P.Minimum()
self.max_op = P.Maximum()
@ -442,7 +440,7 @@ class BertAttention(nn.Cell):
self.has_attention_mask = has_attention_mask
self.use_relative_positions = use_relative_positions
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
@ -471,7 +469,7 @@ class BertAttention(nn.Cell):
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
self.multiply_data = -10000.0
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()

View File

@ -15,6 +15,7 @@
""" test nn ops """
import numpy as np
from numpy.random import normal
import pytest
import mindspore.nn as nn
import mindspore.context as context
@ -311,6 +312,7 @@ def test_op_with_arg_as_input():
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
@pytest.mark.skip("poly in infer")
def test_partial_as_arg():
class PartialArgNet(nn.Cell):
def __init__(self):