forked from mindspore-Ecosystem/mindspore
rename env_getitem/setitem to EnvironGet/EnvironGet
This commit is contained in:
parent
617965c751
commit
19e9571149
|
@ -587,7 +587,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
|
|||
|
||||
std::vector<AnfNodePtr> grads;
|
||||
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
grads.push_back(NewValueNode(newenv));
|
||||
grads.push_back(NewEnviron(b));
|
||||
for (int64_t i = 0; i < tuple_size; ++i) {
|
||||
grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
|
||||
}
|
||||
|
@ -628,7 +628,7 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args
|
|||
|
||||
std::vector<AnfNodePtr> grads;
|
||||
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
grads.push_back(NewValueNode(newenv));
|
||||
grads.push_back(NewEnviron(b));
|
||||
for (int64_t i = 0; i < list_size; ++i) {
|
||||
grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
|
||||
}
|
||||
|
|
|
@ -128,7 +128,7 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
|||
fv_adjoint->second->RegisterKUser(default_val_node, 1);
|
||||
anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
|
||||
}
|
||||
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
|
||||
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvironGet), din, embed_node, default_val_node});
|
||||
MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
|
||||
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
||||
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
|
||||
|
@ -421,7 +421,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
|
|||
fv_adjoint->second->RegisterKUser(node, 1);
|
||||
auto sens = fv_adjoint->second->dout();
|
||||
new_grad_fv = tape_->NewCNode({
|
||||
NewValueNode(prim::kPrimEnvSetItem),
|
||||
NewValueNode(prim::kPrimEnvironSet),
|
||||
new_grad_fv,
|
||||
node,
|
||||
sens,
|
||||
|
@ -448,7 +448,7 @@ AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
|
|||
fv_adjoint.second->RegisterKUser(node, 1);
|
||||
auto sens = fv_adjoint.second->dout();
|
||||
new_grad_fv = tape_->NewCNode({
|
||||
NewValueNode(prim::kPrimEnvSetItem),
|
||||
NewValueNode(prim::kPrimEnvironSet),
|
||||
new_grad_fv,
|
||||
node,
|
||||
sens,
|
||||
|
@ -486,9 +486,9 @@ void DFunctor::MapMorphism() {
|
|||
// Set output for tape closure.
|
||||
AnfNodePtr grad_fv;
|
||||
if (lift_fv_before_grad) {
|
||||
grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
|
||||
grad_fv = AttachFvDoutToTape(NewEnviron(tape_));
|
||||
} else {
|
||||
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
|
||||
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewEnviron(tape_)));
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
|
||||
|
|
|
@ -115,7 +115,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
|||
mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
|
||||
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
|
||||
mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
|
||||
// Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvGetItemTransform in optimizer
|
||||
// Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvironGetTransform in optimizer
|
||||
// can hit its cache if fv_node is same.
|
||||
mindspore::HashMap<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
|
||||
FuncGraphPtr primal_graph_;
|
||||
|
|
|
@ -611,7 +611,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
|
|||
|
||||
std::vector<AnfNodePtr> args;
|
||||
args.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
args.push_back(NewValueNode(newenv));
|
||||
args.push_back(NewEnviron(bprop_fg));
|
||||
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
|
||||
if (!extra_args.empty()) {
|
||||
args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
|
||||
|
@ -622,7 +622,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
|
|||
// Set bprop output as (env, dx)
|
||||
std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
|
||||
std::string python_ops("_tuple_add");
|
||||
auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
|
||||
auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg);
|
||||
auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
|
||||
if (!extra_args.empty()) {
|
||||
extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/environ_conversion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "utils/flags.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
using SymbolicKeyConversionMap =
|
||||
mindspore::HashMap<SymbolicKeyInstancePtr, int64_t, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>;
|
||||
|
||||
namespace {
|
||||
bool IsAbstractEnvType(const abstract::AbstractBasePtr &abs) {
|
||||
if (abs != nullptr && abs->isa<abstract::AbstractScalar>() && abs->GetTypeTrack()->isa<EnvType>()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr TransformAbstractRecursively(const abstract::AbstractBasePtr &orig_abs,
|
||||
const abstract::AbstractBasePtr target_abs) {
|
||||
if (orig_abs == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsAbstractEnvType(orig_abs)) {
|
||||
return target_abs;
|
||||
}
|
||||
if (!orig_abs->isa<abstract::AbstractSequence>()) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto &abs_seq = orig_abs->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_seq);
|
||||
const auto &elements = abs_seq->elements();
|
||||
abstract::AbstractBasePtrList new_elements;
|
||||
bool transformed = false;
|
||||
for (auto elem : elements) {
|
||||
if (elem->isa<abstract::AbstractSequence>()) {
|
||||
auto inner_abs_seq = elem->cast<abstract::AbstractSequencePtr>();
|
||||
auto transformed_abs = TransformAbstractRecursively(inner_abs_seq, target_abs);
|
||||
if (transformed_abs != nullptr) {
|
||||
new_elements.push_back(transformed_abs);
|
||||
transformed = true;
|
||||
} else {
|
||||
new_elements.push_back(elem);
|
||||
}
|
||||
} else if (IsAbstractEnvType(elem)) {
|
||||
new_elements.push_back(target_abs);
|
||||
transformed = true;
|
||||
} else {
|
||||
new_elements.push_back(elem);
|
||||
}
|
||||
}
|
||||
if (transformed) {
|
||||
abstract::AbstractBasePtr new_abs;
|
||||
if (abs_seq->isa<abstract::AbstractTuple>()) {
|
||||
new_abs = std::make_shared<abstract::AbstractTuple>(new_elements);
|
||||
} else if (abs_seq->isa<abstract::AbstractList>()) {
|
||||
new_abs = std::make_shared<abstract::AbstractList>(new_elements);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "abs_seq is not AbstractTuple or AbstractList, but: " << abs_seq->ToString();
|
||||
}
|
||||
return new_abs;
|
||||
}
|
||||
// No transformation.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void TransformNodeAbstractIfEnvType(const AnfNodePtr &node, const abstract::AbstractBasePtr &target_abs) {
|
||||
auto transformed_abs = TransformAbstractRecursively(node->abstract(), target_abs);
|
||||
if (transformed_abs != nullptr) {
|
||||
node->set_abstract(transformed_abs);
|
||||
}
|
||||
}
|
||||
|
||||
TypeId GetValueType(const CNodePtr &cnode) {
|
||||
// (EnvironSet/EnvironGet, environ, key, value/default)
|
||||
if (cnode->inputs().size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "EnvrionSet/EnvironGet cnode should have 4 inputs, but: " << cnode->DebugString();
|
||||
}
|
||||
const auto &value_abstract = cnode->input(3)->abstract();
|
||||
if (value_abstract == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "4th input of EnvironSet/EnvironGet cnode should abstract, but not set, node: "
|
||||
<< cnode->DebugString();
|
||||
}
|
||||
if (IsAbstractEnvType(value_abstract)) {
|
||||
return kObjectTypeEnvType;
|
||||
} else {
|
||||
return kObjectTypeTensorType;
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr GetTransformedKeyNode(const AnfNodePtr &old_key_node, SymbolicKeyConversionMap *para_symbol_map) {
|
||||
const auto &symbolic_key_inst = GetValueNode<SymbolicKeyInstancePtr>(old_key_node);
|
||||
uint64_t transformed_key = 0;
|
||||
auto &symbolic_key_map = *para_symbol_map;
|
||||
auto iter = symbolic_key_map.find(symbolic_key_inst);
|
||||
if (iter != symbolic_key_map.end()) {
|
||||
transformed_key = iter->second;
|
||||
} else {
|
||||
static uint64_t key_counter = 0;
|
||||
transformed_key = ++key_counter;
|
||||
symbolic_key_map.emplace(std::make_pair(symbolic_key_inst, transformed_key));
|
||||
}
|
||||
auto tensor_key = std::make_shared<mindspore::tensor::Tensor>(transformed_key);
|
||||
auto transformed_key_node = NewValueNode(tensor_key);
|
||||
transformed_key_node->set_abstract(tensor_key->ToAbstract());
|
||||
return transformed_key_node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool EnvironConversion(const pipeline::ResourcePtr &resource) {
|
||||
SymbolicKeyConversionMap symbolic_key_map;
|
||||
static AbstractBasePtr scalar_abs = std::make_shared<abstract::AbstractScalar>(kAnyValue, kInt64);
|
||||
static AbstractBasePtr tensor_abs = std::make_shared<abstract::AbstractTensor>(scalar_abs);
|
||||
static std::string attr_name = "value_type";
|
||||
const int kPrimitiveOffset = 0;
|
||||
const int kSymbolicKeyOffset = 2;
|
||||
auto mng = resource->manager();
|
||||
const auto &all_nodes = mng->all_nodes();
|
||||
auto txn = mng->Transact();
|
||||
for (const auto &node : all_nodes) {
|
||||
TransformNodeAbstractIfEnvType(node, tensor_abs);
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimEnvironSet) && !IsPrimitiveCNode(node, prim::kPrimEnvironGet)) {
|
||||
continue;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
// Prim
|
||||
AnfNodePtr transformed_prim_node;
|
||||
const auto &type_id = GetValueType(cnode);
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(kPrimitiveOffset));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const auto &old_attr = prim->GetAttr(attr_name);
|
||||
if (old_attr != nullptr) {
|
||||
const auto old_attr_value = GetValue<int>(old_attr);
|
||||
if (old_attr_value != type_id) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimEnvironSet)) {
|
||||
prim = std::make_shared<Primitive>(prim::kEnvironSet);
|
||||
} else {
|
||||
prim = std::make_shared<Primitive>(prim::kEnvironGet);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_attr(attr_name, MakeValue(static_cast<int>(type_id)));
|
||||
transformed_prim_node = NewValueNode(prim);
|
||||
txn.SetEdge(node, kPrimitiveOffset, transformed_prim_node);
|
||||
}
|
||||
} else {
|
||||
prim->set_attr(attr_name, MakeValue(static_cast<int>(type_id)));
|
||||
}
|
||||
// Abstract of Environ & Value will be set by previous TransformNodeAbstract function.
|
||||
// Key
|
||||
if (!IsValueNode<SymbolicKeyInstance>(cnode->input(kSymbolicKeyOffset))) {
|
||||
MS_LOG(EXCEPTION) << "should be SymbolicKey, but: " << cnode->input(kSymbolicKeyOffset)->ToString();
|
||||
}
|
||||
const auto &transformed_key_node = GetTransformedKeyNode(cnode->input(kSymbolicKeyOffset), &symbolic_key_map);
|
||||
txn.SetEdge(node, kSymbolicKeyOffset, transformed_key_node);
|
||||
}
|
||||
|
||||
txn.Commit();
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,6 +1,4 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -16,24 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "utils/symbolic.h"
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &objPtr) {
|
||||
out << "()";
|
||||
return out;
|
||||
}
|
||||
|
||||
bool EnvInstance::operator==(const EnvInstance &other) const { return true; }
|
||||
|
||||
bool EnvInstance::operator==(const Value &other) const {
|
||||
if (other.isa<EnvInstance>()) {
|
||||
auto other_env_inst = static_cast<const EnvInstance *>(&other);
|
||||
return *this == *other_env_inst;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
std::shared_ptr<EnvInstance> newenv = std::make_shared<EnvInstance>();
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
bool EnvironConversion(const pipeline::ResourcePtr &resource);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_
|
|
@ -19,7 +19,7 @@
|
|||
#include "frontend/optimizer/irpass/branch_culling.h"
|
||||
#include "frontend/optimizer/irpass/cast_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/convert.h"
|
||||
#include "frontend/optimizer/irpass/env_item_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/environ_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/grad_var_prepare.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/inline.h"
|
||||
|
@ -119,26 +119,28 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
|
||||
real_op_eliminate_ = MakeSubstitution(std::make_shared<RealOpEliminate>(), "real_op_eliminate", prim::kPrimRealInner);
|
||||
|
||||
// Env Item Eliminate
|
||||
env_get_item_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
|
||||
env_get_item_add_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvGetItemAddEliminater>(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem);
|
||||
env_get_set_item_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvGetSetItemEliminater>(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem);
|
||||
env_get_item_depend_swap_ =
|
||||
MakeSubstitution(std::make_shared<EnvGetItemDependSwap>(), "env_get_item_depend_swap", prim::kPrimEnvGetItem);
|
||||
// Environ Item Eliminate
|
||||
environ_get_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvironGetEliminater>(), "environ_get_eliminate", prim::kPrimEnvironGet);
|
||||
environ_get_add_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvironGetAddEliminater>(), "environ_get_add_eliminate_", prim::kPrimEnvironGet);
|
||||
environ_get_set_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvironGetSetEliminater>(), "environ_get_set_eliminate", prim::kPrimEnvironGet);
|
||||
environ_get_depend_swap_ =
|
||||
MakeSubstitution(std::make_shared<EnvironGetDependSwap>(), "environ_get_depend_swap", prim::kPrimEnvironGet);
|
||||
environ_add_const_eliminate_ = MakeSubstitution(std::make_shared<EnvironAddConstEliminater>(),
|
||||
"environ_add_const_eliminate_", prim::kPrimEnvironAdd);
|
||||
|
||||
incorporate_env_getitem_bypass_recursive_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
|
||||
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_environ_get_bypass_recursive_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvironGet>(true), "incorporate_environ_get", prim::kPrimEnvironGet);
|
||||
incorporate_environ_get_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitch>(),
|
||||
"incorporate_environ_get_switch", prim::kPrimEnvironGet);
|
||||
incorporate_environ_get_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvironGet>(), "incorporate_environ_get", prim::kPrimEnvironGet);
|
||||
|
||||
incorporate_env_getitem_switch_layer_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitchLayer>(), "incorporate_env_getitem_switch_layer",
|
||||
prim::kPrimEnvGetItem);
|
||||
incorporate_environ_get_switch_layer_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitchLayer>(), "incorporate_environ_get_switch_layer",
|
||||
prim::kPrimEnvironGet);
|
||||
|
||||
// Ref eliminate
|
||||
make_ref_eliminate_ =
|
||||
|
@ -157,8 +159,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
|
||||
float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
|
||||
"float_tuple_getitem_switch", prim::kPrimTupleGetItem);
|
||||
float_env_getitem_switch_ =
|
||||
MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
float_environ_get_switch_ =
|
||||
MakeSubstitution(std::make_shared<FloatEnvironGetSwitch>(), "float_environ_get_switch", prim::kPrimEnvironGet);
|
||||
exchange_switch_depend_value_ =
|
||||
MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);
|
||||
|
||||
|
|
|
@ -65,14 +65,15 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr real_op_eliminate_;
|
||||
|
||||
// Env Item Eliminate
|
||||
SubstitutionPtr env_get_item_eliminate_;
|
||||
SubstitutionPtr env_get_item_add_eliminate_;
|
||||
SubstitutionPtr env_get_set_item_eliminate_;
|
||||
SubstitutionPtr env_get_item_depend_swap_;
|
||||
SubstitutionPtr incorporate_env_getitem_;
|
||||
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
|
||||
SubstitutionPtr incorporate_env_getitem_switch_;
|
||||
SubstitutionPtr incorporate_env_getitem_switch_layer_;
|
||||
SubstitutionPtr environ_get_eliminate_;
|
||||
SubstitutionPtr environ_get_add_eliminate_;
|
||||
SubstitutionPtr environ_get_set_eliminate_;
|
||||
SubstitutionPtr environ_get_depend_swap_;
|
||||
SubstitutionPtr environ_add_const_eliminate_;
|
||||
SubstitutionPtr incorporate_environ_get_;
|
||||
SubstitutionPtr incorporate_environ_get_bypass_recursive_;
|
||||
SubstitutionPtr incorporate_environ_get_switch_;
|
||||
SubstitutionPtr incorporate_environ_get_switch_layer_;
|
||||
|
||||
// Ref eliminate
|
||||
SubstitutionPtr make_ref_eliminate_;
|
||||
|
@ -84,7 +85,7 @@ class OptimizeIRPassLib {
|
|||
// Branch culling
|
||||
SubstitutionPtr switch_simplify_;
|
||||
SubstitutionPtr float_tuple_getitem_switch_;
|
||||
SubstitutionPtr float_env_getitem_switch_;
|
||||
SubstitutionPtr float_environ_get_switch_;
|
||||
SubstitutionPtr exchange_switch_depend_value_;
|
||||
|
||||
SubstitutionPtr switch_partial_eliminater_;
|
||||
|
|
|
@ -58,8 +58,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
{prim::kPrimMomentum, {2, 3}},
|
||||
{prim::kPrimStateSetItem, {1}},
|
||||
{prim::kPrimTupleGetItem, {2}},
|
||||
{prim::kPrimEnvGetItem, {1}},
|
||||
{prim::kPrimEnvSetItem, {1}},
|
||||
{prim::kPrimEnvironGet, {1}},
|
||||
{prim::kPrimEnvironSet, {1}},
|
||||
{prim::kPrimReduceSum, {2}},
|
||||
{prim::kPrimReduceMean, {2}},
|
||||
{prim::kPrimReduceAll, {2}},
|
||||
|
@ -83,7 +83,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
|
||||
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
|
||||
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
|
||||
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
|
||||
{prim::kPrimEnvironGet, {1}}, {prim::kPrimEnvironSet, {1}},
|
||||
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
|
||||
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
|
||||
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
|
||||
|
|
|
@ -66,16 +66,16 @@ class FloatTupleGetItemSwitch : public OptimizerCaller {
|
|||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
|
||||
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
|
||||
class FloatEnvGetItemSwitch : public OptimizerCaller {
|
||||
// {prim::kPrimEnvironGet, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
|
||||
// {prim::kPrimSwitch, X1, {prim::kPrimEnvironGet, X2, X4, X5}, {prim::kPrimEnvironGet, X3, X4, X5}}
|
||||
class FloatEnvironGetSwitch : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
|
||||
MATCH_REPLACE(node,
|
||||
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
|
||||
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)));
|
||||
PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvironGet, true_br, x, x2),
|
||||
PPrimitive(prim::kPrimEnvironGet, false_br, x, x2)));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
@ -37,10 +37,15 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace irpass {
|
||||
namespace internal {
|
||||
class EnvGetitemTransform {
|
||||
constexpr int kEnvironGetSetInputSize = 4;
|
||||
constexpr int kEnvironOffset = 1;
|
||||
constexpr int kSymbolicKeyOffset = 2;
|
||||
constexpr int kValueOffset = 3;
|
||||
|
||||
class EnvironGetTransform {
|
||||
public:
|
||||
EnvGetitemTransform() : cache_() {}
|
||||
~EnvGetitemTransform() = default;
|
||||
EnvironGetTransform() : cache_() {}
|
||||
~EnvironGetTransform() = default;
|
||||
|
||||
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
|
||||
if (cache_.find(fg) == cache_.end()) {
|
||||
|
@ -56,22 +61,22 @@ class EnvGetitemTransform {
|
|||
}
|
||||
|
||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
|
||||
auto env = new_fg->output();
|
||||
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
|
||||
// {prim::kPrimEnvSetItem, env, symbolickey, value}
|
||||
auto &inputs = env->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != 4) {
|
||||
auto environ_node = new_fg->output();
|
||||
while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) {
|
||||
// {prim::kPrimEnvironSet, environ, symbolickey, value}
|
||||
auto &inputs = environ_node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != kEnvironGetSetInputSize) {
|
||||
MS_LOG(WARNING) << "Input size should be 4";
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[kSymbolicKeyOffset])) {
|
||||
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
env = inputs[1];
|
||||
auto value = inputs[3];
|
||||
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
|
||||
environ_node = inputs[kEnvironOffset];
|
||||
auto value = inputs[kValueOffset];
|
||||
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[kSymbolicKeyOffset]);
|
||||
if (*key2 == *key) {
|
||||
new_fg->set_output(value);
|
||||
cache[hash_key] = new_fg;
|
||||
|
@ -79,7 +84,8 @@ class EnvGetitemTransform {
|
|||
return new_fg;
|
||||
}
|
||||
}
|
||||
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
|
||||
new_fg->set_output(
|
||||
new_fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, NewValueNode(key), default_node}));
|
||||
cache[hash_key] = new_fg;
|
||||
}
|
||||
|
||||
|
@ -92,10 +98,10 @@ class EnvGetitemTransform {
|
|||
cache_;
|
||||
};
|
||||
|
||||
class EnvGetitemTransformACrossGraph {
|
||||
class EnvironGetTransformACrossGraph {
|
||||
public:
|
||||
EnvGetitemTransformACrossGraph() : cache_() {}
|
||||
~EnvGetitemTransformACrossGraph() = default;
|
||||
EnvironGetTransformACrossGraph() : cache_() {}
|
||||
~EnvironGetTransformACrossGraph() = default;
|
||||
|
||||
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
|
||||
if (cache_.find(fg) == cache_.end()) {
|
||||
|
@ -120,29 +126,30 @@ class EnvGetitemTransformACrossGraph {
|
|||
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str()));
|
||||
new_fg_outer->set_output(NewValueNode(new_fg));
|
||||
|
||||
auto env = new_fg->output();
|
||||
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
|
||||
// {prim::kPrimEnvSetItem, env, symbolickey, value}
|
||||
auto &inputs = env->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != 4) {
|
||||
auto environ_node = new_fg->output();
|
||||
while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) {
|
||||
// {prim::kPrimEnvironSet, environ, symbolickey, value}
|
||||
auto &inputs = environ_node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != kEnvironGetSetInputSize) {
|
||||
MS_LOG(WARNING) << "Input size should be 4";
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[kSymbolicKeyOffset])) {
|
||||
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
env = inputs[1];
|
||||
auto value = inputs[3];
|
||||
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
|
||||
environ_node = inputs[kEnvironOffset];
|
||||
auto value = inputs[kValueOffset];
|
||||
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[kSymbolicKeyOffset]);
|
||||
if (*key2 == *key) {
|
||||
new_fg->set_output(value);
|
||||
cache[hash_key] = new_fg_outer;
|
||||
return new_fg_outer;
|
||||
}
|
||||
}
|
||||
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
|
||||
new_fg->set_output(
|
||||
new_fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, NewValueNode(key), default_node}));
|
||||
cache[hash_key] = new_fg_outer;
|
||||
}
|
||||
|
||||
|
@ -156,48 +163,48 @@ class EnvGetitemTransformACrossGraph {
|
|||
};
|
||||
} // namespace internal
|
||||
|
||||
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
|
||||
class EnvGetItemEliminater : public AnfVisitor {
|
||||
// {prim::kPrimEnvironGet, C1, C2, Y} -> Y
|
||||
class EnvironGetEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode c1, c2, y;
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
|
||||
(IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node))));
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironGet, c1, c2, y), y,
|
||||
(IsNewEnvironNode(c1.GetNode(node)) && IsVNode(c2.GetNode(node))));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
|
||||
// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
|
||||
class EnvGetItemAddEliminater : public AnfVisitor {
|
||||
// {prim::kPrimEnvironGet, {prim::kPrimEnvironAdd, X, Y}, C, Z} ->
|
||||
// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvironGet, X, C, Z}, {prim::kPrimEnvironGet, Y, C, Z}}
|
||||
class EnvironGetAddEliminater : public AnfVisitor {
|
||||
public:
|
||||
EnvGetItemAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
|
||||
~EnvGetItemAddEliminater() override = default;
|
||||
EnvironGetAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
|
||||
~EnvironGetAddEliminater() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
is_match_ = false;
|
||||
auto IsAddCNode = [](const AnfNodePtr &node) -> bool {
|
||||
return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast<CNodePtr>()->size() == 3;
|
||||
return IsPrimitiveCNode(node, prim::kPrimEnvironAdd) && node->cast<CNodePtr>()->size() == 3;
|
||||
};
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimEnvironGet, {IsAddCNode, IsVNode, IsNode})(node);
|
||||
|
||||
if (!is_match_ || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimEnvGetItem, {...}, C, Z}
|
||||
// {prim::kPrimEnvironGet, {...}, C, Z}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inp1 = cnode->input(1)->cast<CNodePtr>();
|
||||
auto c = cnode->input(2);
|
||||
auto z = cnode->input(3);
|
||||
|
||||
// {prim::kPrimEnvAdd, X, Y}
|
||||
// {prim::kPrimEnvironAdd, X, Y}
|
||||
auto x = inp1->input(1);
|
||||
auto y = inp1->input(2);
|
||||
|
||||
auto fg = node->func_graph();
|
||||
auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z});
|
||||
auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z});
|
||||
auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), x, c, z});
|
||||
auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), y, c, z});
|
||||
|
||||
return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz});
|
||||
}
|
||||
|
@ -209,17 +216,17 @@ class EnvGetItemAddEliminater : public AnfVisitor {
|
|||
ValuePtr PrimHyperAdd_;
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
|
||||
class EnvGetSetItemEliminater : public AnfVisitor {
|
||||
// {prim::kPrimEnvironGet, {prim::kPrimEnvironSet, X, C1, Y}, C2, Z}
|
||||
class EnvironGetSetEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
is_match_ = false;
|
||||
auto IsSetCNode = [](const AnfNodePtr &node) -> bool {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimEnvironSet)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// {prim::kPrimEnvSetItem, X, C1, Y}
|
||||
// {prim::kPrimEnvironSet, X, C1, Y}
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != 4) {
|
||||
return false;
|
||||
|
@ -227,21 +234,21 @@ class EnvGetSetItemEliminater : public AnfVisitor {
|
|||
|
||||
return IsValueNode<SymbolicKeyInstance>(inputs[2]);
|
||||
};
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimEnvironGet, {IsSetCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
|
||||
if (!is_match_ || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimEnvGetItem, {...}, C2, Z}
|
||||
// {prim::kPrimEnvironGet, {...}, C2, Z}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inp1 = cnode->input(1)->cast<CNodePtr>();
|
||||
auto key2 = cnode->input(2);
|
||||
auto c2 = GetValueNode<SymbolicKeyInstancePtr>(key2);
|
||||
auto default_v = cnode->input(3);
|
||||
|
||||
// {prim::kPrimEnvSetItem, X, C1, Y}
|
||||
auto env = inp1->input(1);
|
||||
// {prim::kPrimEnvironSet, X, C1, Y}
|
||||
AnfNodePtr environ_node = inp1->input(1);
|
||||
auto c1 = GetValueNode<SymbolicKeyInstancePtr>(inp1->input(2));
|
||||
auto last_set = inp1->input(3);
|
||||
|
||||
|
@ -249,27 +256,27 @@ class EnvGetSetItemEliminater : public AnfVisitor {
|
|||
return last_set;
|
||||
}
|
||||
|
||||
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
|
||||
// {prim::kPrimEnvSetItem, env, symbolickey, value}
|
||||
auto &inputs = env->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != 4) {
|
||||
while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) {
|
||||
// {prim::kPrimEnvironSet, environ, symbolickey, value}
|
||||
auto &inputs = environ_node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != internal::kEnvironGetSetInputSize) {
|
||||
MS_LOG(WARNING) << "Input size should be 4";
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[internal::kSymbolicKeyOffset])) {
|
||||
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
env = inputs[1];
|
||||
last_set = inputs[3];
|
||||
auto symbolic_c1 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
|
||||
environ_node = inputs[internal::kEnvironOffset];
|
||||
last_set = inputs[internal::kValueOffset];
|
||||
auto symbolic_c1 = GetValueNode<SymbolicKeyInstancePtr>(inputs[internal::kSymbolicKeyOffset]);
|
||||
if (*symbolic_c1 == *c2) {
|
||||
return last_set;
|
||||
}
|
||||
}
|
||||
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v});
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, key2, default_v});
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &) override { is_match_ = true; }
|
||||
|
@ -278,9 +285,9 @@ class EnvGetSetItemEliminater : public AnfVisitor {
|
|||
bool is_match_{false};
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} ->
|
||||
// {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2}
|
||||
class EnvGetItemDependSwap : public OptimizerCaller {
|
||||
// {prim::kPrimEnvironGet, {prim::kPrimDepend, X1, X2}, item, dflt} ->
|
||||
// {prim::kPrimDepend, {prim::kPrimEnvironGet, X1, item, dflt}, X2}
|
||||
class EnvironGetDependSwap : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
|
@ -290,18 +297,30 @@ class EnvGetItemDependSwap : public OptimizerCaller {
|
|||
ScopeGuard scope_guard(scope);
|
||||
|
||||
PatternNode x1, x2, item, dflt;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt),
|
||||
PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvGetItem, x1, item, dflt), x2));
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt),
|
||||
PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvironGet, x1, item, dflt), x2));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
|
||||
class IncorporateEnvGetitem : public AnfVisitor {
|
||||
// {prim::kPrimEnvironAdd, C1, X} -> X
|
||||
// {prim::kPrimEnvironAdd, X, C1} -> X
|
||||
class EnvironAddConstEliminater : public AnfVisitor {
|
||||
public:
|
||||
explicit IncorporateEnvGetitem(bool bypass_recursive = false)
|
||||
: env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
|
||||
~IncorporateEnvGetitem() override = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode c1, x;
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironAdd, c1, x), x, (IsNewEnvironNode(c1.GetNode(node))));
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironAdd, x, c1), x, (IsNewEnvironNode(c1.GetNode(node))));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvironGet, {G, Xs}, C, Y}
|
||||
class IncorporateEnvironGet : public AnfVisitor {
|
||||
public:
|
||||
explicit IncorporateEnvironGet(bool bypass_recursive = false)
|
||||
: environ_get_transform_(), bypass_recursive_(bypass_recursive) {}
|
||||
~IncorporateEnvironGet() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
|
@ -316,13 +335,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
|
|||
}
|
||||
return IsValueNode<FuncGraph>(cnode->input(0));
|
||||
};
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimEnvironGet, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
|
||||
if (!is_match_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimEnvGetItem, {...}, C, Y}
|
||||
// {prim::kPrimEnvironGet, {...}, C, Y}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inp1 = cnode->input(1)->cast<CNodePtr>();
|
||||
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
|
||||
|
@ -331,9 +350,9 @@ class IncorporateEnvGetitem : public AnfVisitor {
|
|||
// {G, Xs}
|
||||
auto inputs = inp1->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
auto new_fg = env_get_item_transform_(fg, key, default_v);
|
||||
auto new_fg = environ_get_transform_(fg, key, default_v);
|
||||
if (fg->recursive() && bypass_recursive_) {
|
||||
MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
|
||||
MS_LOG(DEBUG) << "Bypass EnvironGet transform for recursive fg=" << fg->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
if (new_fg == nullptr) {
|
||||
|
@ -350,15 +369,15 @@ class IncorporateEnvGetitem : public AnfVisitor {
|
|||
|
||||
private:
|
||||
bool is_match_{false};
|
||||
internal::EnvGetitemTransform env_get_item_transform_;
|
||||
internal::EnvironGetTransform environ_get_transform_;
|
||||
bool bypass_recursive_;
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
|
||||
class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
||||
// {prim::kPrimEnvironGet, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
|
||||
class IncorporateEnvironGetSwitch : public AnfVisitor {
|
||||
public:
|
||||
IncorporateEnvGetitemSwitch() : env_get_item_transform_() {}
|
||||
~IncorporateEnvGetitemSwitch() override = default;
|
||||
IncorporateEnvironGetSwitch() : environ_get_transform_() {}
|
||||
~IncorporateEnvironGetSwitch() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
|
@ -374,12 +393,12 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
|||
|
||||
return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch);
|
||||
};
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimEnvironGet, {IsSwNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
if (!is_match_ || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimEnvGetItem, {...}, C, Y}
|
||||
// {prim::kPrimEnvironGet, {...}, C, Y}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inp1 = cnode->input(1)->cast<CNodePtr>();
|
||||
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
|
||||
|
@ -398,8 +417,8 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
|||
auto x = sw->input(1);
|
||||
auto g1 = GetValueNode<FuncGraphPtr>(sw->input(2));
|
||||
auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
|
||||
auto new_g1 = env_get_item_transform_(g1, key, default_v);
|
||||
auto new_g2 = env_get_item_transform_(g2, key, default_v);
|
||||
auto new_g1 = environ_get_transform_(g1, key, default_v);
|
||||
auto new_g2 = environ_get_transform_(g2, key, default_v);
|
||||
if (new_g1 == nullptr || new_g2 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -416,14 +435,14 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
|||
|
||||
private:
|
||||
bool is_match_{false};
|
||||
internal::EnvGetitemTransform env_get_item_transform_;
|
||||
internal::EnvironGetTransform environ_get_transform_;
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
|
||||
class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
||||
// {prim::kPrimEnvironGet, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
|
||||
class IncorporateEnvironGetSwitchLayer : public AnfVisitor {
|
||||
public:
|
||||
IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {}
|
||||
~IncorporateEnvGetitemSwitchLayer() override = default;
|
||||
IncorporateEnvironGetSwitchLayer() : environ_get_transform_() {}
|
||||
~IncorporateEnvironGetSwitchLayer() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
|
@ -431,11 +450,11 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|||
return nullptr;
|
||||
}
|
||||
is_match_ = false;
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimEnvironGet, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
if (!is_match_ || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// {prim::kPrimEnvGetItem, {...}, C, Y}
|
||||
// {prim::kPrimEnvironGet, {...}, C, Y}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inp1 = cnode->input(1)->cast<CNodePtr>();
|
||||
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
|
||||
|
@ -463,7 +482,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|||
std::vector<FuncGraphPtr> graphs{};
|
||||
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
|
||||
auto &graphs_inputs = graphs_cnode->inputs();
|
||||
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
|
||||
const int kMinInputSize = 2;
|
||||
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= kMinInputSize &&
|
||||
IsValueNode<FuncGraph>(graphs_inputs[1])) {
|
||||
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
|
||||
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
|
||||
|
@ -475,7 +495,7 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|||
auto fg = node->func_graph();
|
||||
std::vector<AnfNodePtr> layers;
|
||||
for (auto &graph : graphs) {
|
||||
auto fg_transform = env_get_item_transform_(graph, key, default_v);
|
||||
auto fg_transform = environ_get_transform_(graph, key, default_v);
|
||||
if (fg_transform == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -493,9 +513,9 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|||
|
||||
private:
|
||||
bool is_match_{false};
|
||||
internal::EnvGetitemTransformACrossGraph env_get_item_transform_;
|
||||
internal::EnvironGetTransformACrossGraph environ_get_transform_;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_
|
|
@ -648,15 +648,15 @@ class IncorporateGetitemSwitch : public AnfVisitor {
|
|||
MS_EXCEPTION_IF_NULL(switch_call);
|
||||
const auto &switch_call_cnode = switch_call->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_call_cnode);
|
||||
// If exist env_getitem/env_setitem in this funcgraph or
|
||||
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
|
||||
// If exist EnvironGet/EnvironSet in this funcgraph or
|
||||
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any EnvironGet or EnvironSet;
|
||||
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index;
|
||||
auto switch_call_users_counter = MultipleUseOfSwitch(switch_call, fg, &tp_cnodes_and_index);
|
||||
bool multiple_use = (tp_cnodes_and_index.size() > 1);
|
||||
if (g1_output_is_shrinkable && g2_output_is_shrinkable && multiple_use &&
|
||||
(tp_cnodes_and_index.size() == switch_call_users_counter)) {
|
||||
if (!internal::HasMoreJ(optimizer) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
|
||||
!ExistEnvNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) {
|
||||
if (!internal::HasMoreJ(optimizer) && !ExistEnvironNode(fg) && !ExistEnvironNodeInTupleItem(g1_) &&
|
||||
!ExistEnvironNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) {
|
||||
MS_LOG(DEBUG) << "No more j, will shrink. Node: " << node->DebugString()
|
||||
<< ", switch: " << switch_->DebugString();
|
||||
const auto g1_output_size = internal::GetOutputSize(g1_->output());
|
||||
|
@ -834,15 +834,15 @@ class IncorporateGetitemSwitch : public AnfVisitor {
|
|||
return node_users.size();
|
||||
}
|
||||
|
||||
static bool inline ExistEnvNode(const FuncGraphPtr &fg) {
|
||||
static bool inline ExistEnvironNode(const FuncGraphPtr &fg) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto &nodes = fg->value_nodes();
|
||||
return std::any_of(nodes.begin(), nodes.end(), [](const auto &node) {
|
||||
return IsPrimitive(node.first, prim::kPrimEnvSetItem) || IsPrimitive(node.first, prim::kPrimEnvGetItem);
|
||||
return IsPrimitive(node.first, prim::kPrimEnvironSet) || IsPrimitive(node.first, prim::kPrimEnvironGet);
|
||||
});
|
||||
}
|
||||
|
||||
static bool inline ExistEnvNodeInTupleItem(const FuncGraphPtr &fg) {
|
||||
static bool inline ExistEnvironNodeInTupleItem(const FuncGraphPtr &fg) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
const auto &output = fg->output();
|
||||
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
|
@ -852,7 +852,7 @@ class IncorporateGetitemSwitch : public AnfVisitor {
|
|||
const auto &inputs = cnode->inputs();
|
||||
return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &input) {
|
||||
auto sub_fg = GetValueNode<FuncGraphPtr>(input);
|
||||
if (sub_fg != nullptr && ExistEnvNode(sub_fg)) {
|
||||
if (sub_fg != nullptr && ExistEnvironNode(sub_fg)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -167,12 +167,12 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state) {
|
|||
return input_monad;
|
||||
}
|
||||
|
||||
bool ExistEnvGetItem(FuncGraphManagerPtr manager) {
|
||||
bool ExistEnvironGet(FuncGraphManagerPtr manager) {
|
||||
const FuncGraphSet &fgs = manager->func_graphs();
|
||||
for (auto &fg : fgs) {
|
||||
auto &nodes = fg->value_nodes();
|
||||
bool exist = std::any_of(nodes.begin(), nodes.end(),
|
||||
[](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvGetItem); });
|
||||
[](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvironGet); });
|
||||
if (exist) {
|
||||
return true;
|
||||
}
|
||||
|
@ -181,10 +181,10 @@ bool ExistEnvGetItem(FuncGraphManagerPtr manager) {
|
|||
}
|
||||
|
||||
// Convert:
|
||||
// cnode1 = env_setitem(EnvInstance, para1, attach1)
|
||||
// cnode2 = env_setitem(cnode1, para2, attach2)
|
||||
// cnode1 = EnvironSet(EnvCreate(), para1, attach1)
|
||||
// cnode2 = EnvironSet(cnode1, para2, attach2)
|
||||
// ...
|
||||
// cnode_n = env_setitem(cnode_n-1, para_n-1, attach_n-1)
|
||||
// cnode_n = EnvironSet(cnode_n-1, para_n-1, attach_n-1)
|
||||
// maketuple = maketuple(cnode_n, ...)
|
||||
// updatestate = updatestate(umonad, maketuple)
|
||||
// To:
|
||||
|
@ -194,26 +194,26 @@ AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_st
|
|||
std::vector<AnfNodePtr> env_nodes;
|
||||
std::vector<AnfNodePtr> new_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
size_t input_size = make_tuple->inputs().size();
|
||||
bool has_env_setitem = false;
|
||||
bool has_environ_set = false;
|
||||
for (size_t i = 1; i < input_size; i++) {
|
||||
auto node = make_tuple->input(i);
|
||||
if (IsPrimitiveCNode(node, prim::kPrimEnvSetItem) && OnlyUsedByOneNode(node, make_tuple)) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimEnvironSet) && OnlyUsedByOneNode(node, make_tuple)) {
|
||||
env_nodes.emplace_back(node);
|
||||
has_env_setitem = true;
|
||||
has_environ_set = true;
|
||||
} else if (node->isa<CNode>() && !IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
new_maketuple_inputs.emplace_back(node);
|
||||
}
|
||||
}
|
||||
if (!has_env_setitem) {
|
||||
if (!has_environ_set) {
|
||||
return nullptr;
|
||||
}
|
||||
// Check env_setitem in MakeTuple
|
||||
// Check EnvironSet in MakeTuple
|
||||
auto mgr = GetManager(update_state);
|
||||
if (mgr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// If exist env_getitem, don't eliminate env_setitem.
|
||||
if (ExistEnvGetItem(mgr)) {
|
||||
// If exist EnvironGet, don't eliminate EnvironSet.
|
||||
if (ExistEnvironGet(mgr)) {
|
||||
return nullptr;
|
||||
}
|
||||
const size_t first_index = 1;
|
||||
|
@ -228,7 +228,7 @@ AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_st
|
|||
auto env_cnode = env->cast<CNodePtr>();
|
||||
auto env_input = env_cnode->input(first_index);
|
||||
auto attach = env_cnode->input(attach_index);
|
||||
if (IsPrimitiveCNode(env_input, prim::kPrimEnvSetItem) && OnlyUsedByOneNode(env_input, env_cnode)) {
|
||||
if (IsPrimitiveCNode(env_input, prim::kPrimEnvironSet) && OnlyUsedByOneNode(env_input, env_cnode)) {
|
||||
env_nodes.emplace_back(env_input);
|
||||
new_maketuple_inputs.insert(new_maketuple_inputs.begin() + no_env_node_size, attach);
|
||||
}
|
||||
|
|
|
@ -464,12 +464,11 @@ constexpr char TUPLE_DIV[] = "tuple_div";
|
|||
constexpr char TUPLE_TO_ARRAY[] = "tuple_to_array";
|
||||
constexpr char VIRTUALLOSS[] = "VirtualLoss";
|
||||
constexpr char RETURN[] = "Return";
|
||||
constexpr char ENV_GETITEM[] = "env_getitem";
|
||||
constexpr char ENVIRONGET[] = "EnvironGet";
|
||||
constexpr char IDENTITY[] = "identity";
|
||||
constexpr char PARTIAL[] = "partial";
|
||||
constexpr char ENVSETITEM[] = "env_setitem";
|
||||
constexpr char ENVGETITEM[] = "env_getitem";
|
||||
constexpr char ENVADD[] = "env_add";
|
||||
constexpr char ENVIRONSET[] = "EnvironSet";
|
||||
constexpr char ENVIRONADD[] = "EnvironAdd";
|
||||
constexpr char MAKEREFKEY[] = "MakeRefKey";
|
||||
constexpr char MAKEREF[] = "make_ref";
|
||||
constexpr char GETREFKEY[] = "get_ref_key";
|
||||
|
|
|
@ -2862,7 +2862,7 @@ static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
|
|||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimEnvironGet)) {
|
||||
return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
|
||||
} else {
|
||||
return node;
|
||||
|
|
|
@ -45,7 +45,6 @@
|
|||
|
||||
namespace py = pybind11;
|
||||
|
||||
using EnvInstance = mindspore::EnvInstance;
|
||||
using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy;
|
||||
using Pipeline = mindspore::pipeline::Pipeline;
|
||||
using PrimitivePy = mindspore::PrimitivePy;
|
||||
|
@ -118,8 +117,6 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.")
|
||||
.def("generate_arguments_key", &GraphExecutorPy::GenerateArgumentsKey, "Generate unique key of argument.");
|
||||
|
||||
(void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
|
||||
|
||||
(void)m.def("real_run_op", &mindspore::pynative::RealRunOp, "Run op pynatively.");
|
||||
(void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
|
||||
(void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
|
||||
|
|
|
@ -502,7 +502,6 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
|||
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
|
||||
std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
|
||||
std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
|
||||
std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
|
||||
std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
|
||||
[](const py::object &obj) -> ValuePtr {
|
||||
auto res =
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||
#include "frontend/optimizer/recompute.h"
|
||||
#include "frontend/optimizer/slice_activation_in_recompute.h"
|
||||
#include "frontend/optimizer/environ_conversion.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
|
@ -198,12 +199,12 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
|
|||
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
|
||||
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
|
||||
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
|
||||
opt::OptPassConfig env_eliminate = opt::OptPassConfig({
|
||||
opt::OptPassConfig environ_eliminate = opt::OptPassConfig({
|
||||
irpass.incorporate_call_,
|
||||
irpass.incorporate_call_switch_,
|
||||
irpass.incorporate_getitem_set_,
|
||||
});
|
||||
(void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate));
|
||||
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate));
|
||||
}
|
||||
|
||||
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map);
|
||||
|
@ -264,10 +265,11 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.tuple_list_get_item_depend_reorder_,
|
||||
irpass.tuple_list_convert_item_index_to_positive_,
|
||||
|
||||
irpass.env_get_item_eliminate_,
|
||||
irpass.env_get_item_add_eliminate_,
|
||||
irpass.env_get_set_item_eliminate_,
|
||||
irpass.env_get_item_depend_swap_,
|
||||
irpass.environ_get_eliminate_,
|
||||
irpass.environ_get_add_eliminate_,
|
||||
irpass.environ_get_set_eliminate_,
|
||||
irpass.environ_get_depend_swap_,
|
||||
irpass.environ_add_const_eliminate_,
|
||||
|
||||
irpass.cast_eliminate_,
|
||||
irpass.reshape_eliminate_,
|
||||
|
@ -301,16 +303,16 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.specialize_transform_,
|
||||
irpass.merge_addn_,
|
||||
irpass.float_tuple_getitem_switch_,
|
||||
irpass.float_env_getitem_switch_,
|
||||
irpass.float_environ_get_switch_,
|
||||
irpass.inline_,
|
||||
irpass.updatestate_useless_node_eliminater_,
|
||||
irpass.tuple_list_get_item_eliminator_,
|
||||
irpass.incorporate_getitem_set_,
|
||||
irpass.incorporate_call_,
|
||||
irpass.incorporate_call_switch_,
|
||||
irpass.incorporate_env_getitem_bypass_recursive_,
|
||||
irpass.incorporate_env_getitem_switch_,
|
||||
irpass.env_get_item_eliminate_,
|
||||
irpass.incorporate_environ_get_bypass_recursive_,
|
||||
irpass.incorporate_environ_get_switch_,
|
||||
irpass.environ_get_eliminate_,
|
||||
irpass.depend_value_elim_,
|
||||
irpass.all_reduce_const_elim_,
|
||||
},
|
||||
|
@ -434,13 +436,14 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.stopgrad_eliminater_,
|
||||
irpass.special_op_eliminate_,
|
||||
irpass.get_make_ref_eliminate_,
|
||||
irpass.incorporate_env_getitem_,
|
||||
irpass.incorporate_env_getitem_switch_,
|
||||
irpass.env_get_item_eliminate_,
|
||||
irpass.env_get_item_add_eliminate_,
|
||||
irpass.env_get_set_item_eliminate_,
|
||||
irpass.env_get_item_depend_swap_,
|
||||
irpass.incorporate_env_getitem_switch_layer_,
|
||||
irpass.incorporate_environ_get_,
|
||||
irpass.incorporate_environ_get_switch_,
|
||||
irpass.environ_get_eliminate_,
|
||||
irpass.environ_get_add_eliminate_,
|
||||
irpass.environ_get_set_eliminate_,
|
||||
irpass.environ_get_depend_swap_,
|
||||
irpass.environ_add_const_eliminate_,
|
||||
irpass.incorporate_environ_get_switch_layer_,
|
||||
irpass.value_based_eliminate_,
|
||||
irpass.virtual_accu_grad_,
|
||||
irpass.virtual_assign_add_,
|
||||
|
@ -732,6 +735,15 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool EnvironConversionPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
if (enable_closure) {
|
||||
opt::EnvironConversion(res);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_before_recompute", OptBeforeRecomputeGroup},
|
||||
{"opt_a", OptPassAGroup},
|
||||
|
@ -744,6 +756,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"add_cache_embedding", AddCacheEmbeddingPass},
|
||||
{"add_recomputation", AddRecomputationPass},
|
||||
{"cse_after_recomputation", OptAfterRecomputeGroup},
|
||||
{"environ_conv", EnvironConversionPass},
|
||||
{"slice_recompute_activation", SliceRecomputeActivationPass}};
|
||||
|
||||
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
|
|
|
@ -123,12 +123,14 @@ AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -43,7 +43,15 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return args_spec_list[0];
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: None.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 0);
|
||||
static AbstractBasePtr abs_env = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
return abs_env;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
// args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
|
||||
|
@ -53,7 +61,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
TypePtr type = key->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kObjectTypeSymbolicKeyType) {
|
||||
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
|
||||
MS_LOG(EXCEPTION) << "EnvironGet evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
|
||||
}
|
||||
|
||||
auto context = MsContext::GetInstance();
|
||||
|
@ -76,7 +84,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
return expected;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 3);
|
||||
|
@ -86,7 +94,7 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_EXCEPTION_IF_NULL(key_value_ptr);
|
||||
auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
|
||||
if (key_value_track == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
|
||||
MS_LOG(EXCEPTION) << "EnvironSet evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
|
||||
<< key_value_ptr->ToString();
|
||||
}
|
||||
auto expected = key_value_track->abstract();
|
||||
|
@ -94,8 +102,8 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
|
|
|
@ -200,9 +200,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimLoad, R{InferImplLoad, nullptr, true}},
|
||||
// Set impl to null as it will use PartialEvaluator;
|
||||
{prim::kPrimPartial, R{nullptr, nullptr, true}},
|
||||
{prim::kPrimEnvGetItem, R{InferImplEnvGetItem, nullptr, true}},
|
||||
{prim::kPrimEnvSetItem, R{InferImplEnvSetItem, nullptr, true}},
|
||||
{prim::kPrimEnvAdd, R{InferImplEnvAdd, nullptr, true}},
|
||||
{prim::kPrimEnvironCreate, R{InferImplEnvironCreate, nullptr, true}},
|
||||
{prim::kPrimEnvironGet, R{InferImplEnvironGet, nullptr, true}},
|
||||
{prim::kPrimEnvironSet, R{InferImplEnvironSet, nullptr, true}},
|
||||
{prim::kPrimEnvironAdd, R{InferImplEnvironAdd, nullptr, true}},
|
||||
{prim::kPrimMakeRefKey, R{InferImplMakeRefKey, nullptr, true}},
|
||||
{prim::kPrimMakeRef, R{InferImplMakeRef, nullptr, true}},
|
||||
{prim::kPrimGetRefKey, R{InferImplGetRefKey, nullptr, true}},
|
||||
|
|
|
@ -120,6 +120,13 @@ constexpr auto kSparseApplyAdadelta = "SparseApplyAdadelta";
|
|||
constexpr auto kRoll = "Roll";
|
||||
constexpr auto kTanh = "Tanh";
|
||||
|
||||
// Others
|
||||
constexpr auto kEnvironCreate = "EnvironCreate";
|
||||
constexpr auto kEnvironSet = "EnvironSet";
|
||||
constexpr auto kEnvironGet = "EnvironGet";
|
||||
constexpr auto kEnvironAdd = "EnvironAdd";
|
||||
|
||||
//
|
||||
// Here list all primitives used in backend or some special primitives used by core.
|
||||
// GetNext
|
||||
inline const PrimitivePtr kPrimGetNext = std::make_shared<Primitive>(kGetNext);
|
||||
|
@ -708,9 +715,10 @@ inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_ap
|
|||
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
|
||||
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
|
||||
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
|
||||
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
|
||||
inline const PrimitivePtr kPrimEnvironCreate = std::make_shared<Primitive>(kEnvironCreate);
|
||||
inline const PrimitivePtr kPrimEnvironSet = std::make_shared<Primitive>(kEnvironSet);
|
||||
inline const PrimitivePtr kPrimEnvironGet = std::make_shared<Primitive>(kEnvironGet);
|
||||
inline const PrimitivePtr kPrimEnvironAdd = std::make_shared<Primitive>(kEnvironAdd);
|
||||
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
|
||||
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
|
||||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
|
|
|
@ -26,7 +26,9 @@
|
|||
|
||||
#include "utils/hash_map.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
class SymbolicKeyInstance : public Value {
|
||||
|
@ -97,37 +99,15 @@ struct SymbolicKeyInstanceEqual {
|
|||
}
|
||||
};
|
||||
|
||||
using EnvInstanceContentsMap =
|
||||
mindspore::HashMap<SymbolicKeyInstancePtr, Any, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>;
|
||||
static inline AnfNodePtr NewEnviron(const FuncGraphPtr &fg) {
|
||||
return fg->NewCNode({NewValueNode(prim::kPrimEnvironCreate)});
|
||||
}
|
||||
|
||||
// Environment mapping keys to values.
|
||||
// Keys are SymbolicKeyInstances, which represent nodes in the graph along
|
||||
// with inferred properties.
|
||||
class EnvInstance : public Value {
|
||||
public:
|
||||
friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &env);
|
||||
static inline bool IsNewEnvironNode(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimEnvironCreate); }
|
||||
|
||||
EnvInstance() = default;
|
||||
~EnvInstance() override = default;
|
||||
MS_DECLARE_PARENT(EnvInstance, Value);
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<EnvInstance>(), std::make_shared<EnvType>());
|
||||
}
|
||||
bool operator==(const EnvInstance &other) const;
|
||||
bool operator==(const Value &other) const override;
|
||||
EnvInstance(const EnvInstance &v) : Value(v) {}
|
||||
EnvInstance(EnvInstance &&v) = default;
|
||||
EnvInstance &operator=(EnvInstance &&src) noexcept { return *this; };
|
||||
|
||||
std::size_t hash() const override {
|
||||
// deterministic characteristic of member variables.
|
||||
return tid();
|
||||
}
|
||||
};
|
||||
|
||||
using EnvInstancePtr = std::shared_ptr<EnvInstance>;
|
||||
|
||||
extern std::shared_ptr<EnvInstance> newenv;
|
||||
static inline abstract::AbstractBasePtr MakeEnvironAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_SYMBOLIC_H_
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
"""Data type for MindSpore."""
|
||||
|
||||
import numpy as np
|
||||
from .._c_expression import typing, EnvInstance_
|
||||
from .._c_expression import typing
|
||||
from .._c_expression.typing import Type
|
||||
|
||||
__dtype__ = [
|
||||
|
@ -161,7 +161,6 @@ _simple_types = {
|
|||
np.float16: float16,
|
||||
np.float32: float32,
|
||||
np.float64: float64,
|
||||
EnvInstance_: env_type,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from functools import partial
|
|||
from types import FunctionType
|
||||
|
||||
from mindspore import context
|
||||
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
||||
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
||||
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_
|
||||
from ...common import dtype as mstype
|
||||
from ...common.api import ms_function, _pynative_executor, _wrap_func
|
||||
|
@ -29,7 +29,7 @@ from ..operations import _grad_ops
|
|||
from .. import operations as P
|
||||
from .. import signature as sig
|
||||
|
||||
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
|
||||
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
|
||||
|
||||
|
||||
def add_flags(fn=None, **flags):
|
||||
|
@ -852,7 +852,7 @@ zip_operation = _ZipOperation('zip_operation')
|
|||
env_get = MultitypeFuncGraph("env_get")
|
||||
|
||||
|
||||
env_getitem = Primitive('env_getitem')
|
||||
environ_get = Primitive('EnvironGet')
|
||||
ref_to_embed = _grad_ops.RefToEmbed()
|
||||
zeros_like = P.ZerosLike()
|
||||
|
||||
|
@ -860,4 +860,4 @@ zeros_like = P.ZerosLike()
|
|||
@env_get.register("EnvType", "Tensor")
|
||||
def _tensor_env_get(env, parameter):
|
||||
"""Used to get env."""
|
||||
return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))
|
||||
return environ_get(env, ref_to_embed(parameter), zeros_like(parameter))
|
||||
|
|
|
@ -254,7 +254,7 @@ def _add_env(x, y):
|
|||
Returns:
|
||||
EnvType, equal to x + y.
|
||||
"""
|
||||
return F.env_add(x, y)
|
||||
return F.environ_add(x, y)
|
||||
|
||||
|
||||
@add.register("Tuple", "Tuple")
|
||||
|
|
|
@ -58,9 +58,6 @@ def _ones_like_sparse_tensor(x):
|
|||
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
|
||||
|
||||
|
||||
newenv = base.EnvInstance_()
|
||||
|
||||
|
||||
@ones_like_leaf.register("Function")
|
||||
def _ones_like_func(x):
|
||||
"""
|
||||
|
@ -70,10 +67,10 @@ def _ones_like_func(x):
|
|||
x (Function): x
|
||||
|
||||
Returns:
|
||||
EnvInstance_, value is newenv.
|
||||
A instance of EnvType.
|
||||
"""
|
||||
# Unused parameters are placeholders.
|
||||
return newenv
|
||||
return F.environ_create()
|
||||
|
||||
|
||||
ones_like = base.HyperMap(ones_like_leaf)
|
||||
|
|
|
@ -37,9 +37,6 @@ def _zeros_like_bool(x):
|
|||
return False
|
||||
|
||||
|
||||
newenv = base.EnvInstance_()
|
||||
|
||||
|
||||
@zeros_like_leaf.register("Function")
|
||||
def _zeros_like_func(x):
|
||||
"""
|
||||
|
@ -49,10 +46,10 @@ def _zeros_like_func(x):
|
|||
x (Function): x
|
||||
|
||||
Returns:
|
||||
EnvInstance_, value is newenv.
|
||||
A instance of EnvType.
|
||||
"""
|
||||
# Unused parameters are placeholders.
|
||||
return newenv
|
||||
return F.environ_create()
|
||||
|
||||
|
||||
@zeros_like_leaf.register("Tensor")
|
||||
|
|
|
@ -456,9 +456,10 @@ zeros_like = P.ZerosLike()
|
|||
distribute = Primitive('distribute')
|
||||
embed = Primitive('embed')
|
||||
ref_to_embed = _grad_ops.RefToEmbed()
|
||||
env_setitem = Primitive('env_setitem')
|
||||
env_getitem = Primitive('env_getitem')
|
||||
env_add = Primitive('env_add')
|
||||
environ_create = Primitive('EnvironCreate')
|
||||
environ_set = Primitive('EnvironSet')
|
||||
environ_get = Primitive('EnrironGet')
|
||||
environ_add = Primitive('EnvironAdd')
|
||||
J = Primitive('J')
|
||||
SliceGetItem = Primitive("SliceGetItem")
|
||||
switch = Primitive('Switch')
|
||||
|
|
|
@ -2047,7 +2047,7 @@ class RefToEmbed(Primitive):
|
|||
Make a key from Ref.
|
||||
|
||||
The Key is a symbolic_key, is a embedding on Parameter, which is used as a key of the variable in env_type,
|
||||
and get items by operation `env_get_item` with the symbolic_key instance. The `Parameter` is a ref.
|
||||
and get items by operation `EnvironGet` with the symbolic_key instance. The `Parameter` is a ref.
|
||||
|
||||
Inputs:
|
||||
- **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
|
||||
|
|
|
@ -346,19 +346,28 @@ TEST_F(TestOps, EmbedTest) {
|
|||
ASSERT_EQ(prim->name(), kPrimEmbed->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, EnvSetItemTest) {
|
||||
auto prim = std::make_shared<Primitive>("env_setitem");
|
||||
ASSERT_EQ(prim->name(), kPrimEnvSetItem->name());
|
||||
/// Feature: Check primitive name equivalence
|
||||
/// Description: EnvironSet primitive name equivalence
|
||||
/// Expectation: Equal
|
||||
TEST_F(TestOps, EnvironSetTest) {
|
||||
auto prim = std::make_shared<Primitive>("EnvironSet");
|
||||
ASSERT_EQ(prim->name(), kPrimEnvironSet->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, EnvGetItemTest) {
|
||||
auto prim = std::make_shared<Primitive>("env_getitem");
|
||||
ASSERT_EQ(prim->name(), kPrimEnvGetItem->name());
|
||||
/// Feature: Check primitive name equivalence
|
||||
/// Description: EnvironGet primitive name equivalence
|
||||
/// Expectation: Equal
|
||||
TEST_F(TestOps, EnvironGetTest) {
|
||||
auto prim = std::make_shared<Primitive>("EnvironGet");
|
||||
ASSERT_EQ(prim->name(), kPrimEnvironGet->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, EnvAddest) {
|
||||
auto prim = std::make_shared<Primitive>("env_add");
|
||||
ASSERT_EQ(prim->name(), kPrimEnvAdd->name());
|
||||
/// Feature: Check primitive name equivalence
|
||||
/// Description: EnvironAdd primitive name equivalence
|
||||
/// Expectation: Equal
|
||||
TEST_F(TestOps, EnvironAddTest) {
|
||||
auto prim = std::make_shared<Primitive>("EnvironAdd");
|
||||
ASSERT_EQ(prim->name(), kPrimEnvironAdd->name());
|
||||
}
|
||||
|
||||
// Neural Network
|
||||
|
|
|
@ -370,81 +370,81 @@ TEST_F(TestPrim, test_partial) {
|
|||
ASSERT_TRUE(res->ToString() == expected->ToString());
|
||||
}
|
||||
|
||||
// def test_env(x, y):
|
||||
// return env_setitem(newenv, embed(x), y)
|
||||
TEST_F(TestPrim, test_env_setitem) {
|
||||
/// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
|
||||
/// Description: Check inference result of primitive EnvironSet.
|
||||
/// Expectation: Equal
|
||||
TEST_F(TestPrim, test_environ_set) {
|
||||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvironSet, 3);
|
||||
|
||||
AbstractBasePtr abstract_env = ToAbstract(newenv);
|
||||
AbstractBasePtr abstract_environ = MakeEnvironAbstract();
|
||||
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||
args_spec_list = {abstract_environ, embed_x, abstract_y};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
ASSERT_TRUE(*res == *exp);
|
||||
}
|
||||
|
||||
// def test_env(x, y, z):
|
||||
// e = env_setitem(newenv, embed(x), y)
|
||||
// return env_getitem(e, embed(x), z)
|
||||
TEST_F(TestPrim, test_env_getitem) {
|
||||
/// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
|
||||
/// Description: Check inference result of primitive EnvironGet.
|
||||
/// Expectation: Equal
|
||||
TEST_F(TestPrim, test_environ_get) {
|
||||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||
|
||||
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||
FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);
|
||||
|
||||
AbstractBasePtr abstract_env = ToAbstract(newenv);
|
||||
AbstractBasePtr abstract_environ = MakeEnvironAbstract();
|
||||
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||
args_spec_list = {abstract_environ, embed_x, abstract_y};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
ASSERT_TRUE(*res == *exp);
|
||||
|
||||
FuncGraphPtr graph_getitem = MakeFuncGraph(prim::kPrimEnvGetItem, 3);
|
||||
FuncGraphPtr graph_environ_get = MakeFuncGraph(prim::kPrimEnvironGet, 3);
|
||||
|
||||
AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
|
||||
args_spec_list = {res, embed_x, abstract_z};
|
||||
|
||||
res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
|
||||
res = engine_->Run(graph_environ_get, args_spec_list).inferred->abstract();
|
||||
|
||||
ASSERT_TRUE(*res == *abstract_x);
|
||||
}
|
||||
|
||||
// def test_env(x, y, z):
|
||||
// e1 = env_setitem(newenv, embed(x), y)
|
||||
// e2 = env_setitem(newenv, embed(x), z)
|
||||
// return env_add(e1, e2)
|
||||
TEST_F(TestPrim, test_env_add) {
|
||||
/// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
|
||||
/// Description: Check inference result of primitive EnvironAdd.
|
||||
/// Expectation: Equal
|
||||
TEST_F(TestPrim, test_environ_add) {
|
||||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||
|
||||
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||
FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);
|
||||
|
||||
AbstractBasePtr abstract_env = ToAbstract(newenv);
|
||||
AbstractBasePtr abstract_environ = MakeEnvironAbstract();
|
||||
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||
args_spec_list = {abstract_environ, embed_x, abstract_y};
|
||||
|
||||
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
ASSERT_TRUE(*abstract_e1 == *exp);
|
||||
|
||||
AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_z};
|
||||
args_spec_list = {abstract_environ, embed_x, abstract_z};
|
||||
|
||||
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
|
||||
ASSERT_TRUE(*abstract_e2 == *exp);
|
||||
|
||||
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
|
||||
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvironAdd, 2);
|
||||
args_spec_list = {abstract_e1, abstract_e2};
|
||||
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
|
||||
|
||||
|
|
|
@ -315,7 +315,7 @@ def test_while_if():
|
|||
|
||||
|
||||
# should compile success without zeros_like_tensor args mismatch, the generated graph files
|
||||
# should not contain env_getitem or env_setitem.
|
||||
# should not contain EnvironGet or EnvironSet.
|
||||
# InsertGradientOf primitive will make func_graph bprop_construct had BackPropAutoMonad flag set,
|
||||
# so all graph it used will be checked if any side effect it has, so the hyper_map_zeros_like
|
||||
# will have U as parameter, but the call site zeros_like(fv) don't have U argument.
|
||||
|
@ -347,7 +347,7 @@ def test_grad_fv_and_insert_gradient_of():
|
|||
|
||||
net = FvAndInsertGradientNet()
|
||||
input_data = Tensor(np.array([1.0], np.float32))
|
||||
# if use grad_all_list, the generated graph will have env_setitem
|
||||
# if use grad_all_list, the generated graph will have EnvironSet
|
||||
# as gradient for inputs is constant zero, so it will depend on result of grad.
|
||||
grad_net = grad_by_list(net, ParameterTuple(net.trainable_params()))
|
||||
print(grad_net(input_data))
|
||||
|
|
Loading…
Reference in New Issue