rename env_getitem/setitem to EnvironGet/EnvironGet

This commit is contained in:
zhousiyi 2021-12-31 01:14:21 +00:00
parent 617965c751
commit 19e9571149
33 changed files with 523 additions and 314 deletions

View File

@ -587,7 +587,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
std::vector<AnfNodePtr> grads;
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
grads.push_back(NewValueNode(newenv));
grads.push_back(NewEnviron(b));
for (int64_t i = 0; i < tuple_size; ++i) {
grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
}
@ -628,7 +628,7 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args
std::vector<AnfNodePtr> grads;
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
grads.push_back(NewValueNode(newenv));
grads.push_back(NewEnviron(b));
for (int64_t i = 0; i < list_size; ++i) {
grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
}

View File

@ -128,7 +128,7 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
fv_adjoint->second->RegisterKUser(default_val_node, 1);
anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
}
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvironGet), din, embed_node, default_val_node});
MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
@ -421,7 +421,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
fv_adjoint->second->RegisterKUser(node, 1);
auto sens = fv_adjoint->second->dout();
new_grad_fv = tape_->NewCNode({
NewValueNode(prim::kPrimEnvSetItem),
NewValueNode(prim::kPrimEnvironSet),
new_grad_fv,
node,
sens,
@ -448,7 +448,7 @@ AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
fv_adjoint.second->RegisterKUser(node, 1);
auto sens = fv_adjoint.second->dout();
new_grad_fv = tape_->NewCNode({
NewValueNode(prim::kPrimEnvSetItem),
NewValueNode(prim::kPrimEnvironSet),
new_grad_fv,
node,
sens,
@ -486,9 +486,9 @@ void DFunctor::MapMorphism() {
// Set output for tape closure.
AnfNodePtr grad_fv;
if (lift_fv_before_grad) {
grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
grad_fv = AttachFvDoutToTape(NewEnviron(tape_));
} else {
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewEnviron(tape_)));
}
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};

View File

@ -115,7 +115,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
// Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvGetItemTransform in optimizer
// Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvironGetTransform in optimizer
// can hit its cache if fv_node is same.
mindspore::HashMap<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
FuncGraphPtr primal_graph_;

View File

@ -611,7 +611,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
args.push_back(NewValueNode(newenv));
args.push_back(NewEnviron(bprop_fg));
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
if (!extra_args.empty()) {
args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
@ -622,7 +622,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
// Set bprop output as (env, dx)
std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
std::string python_ops("_tuple_add");
auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg);
auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
if (!extra_args.empty()) {
extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));

View File

@ -0,0 +1,186 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/environ_conversion.h"
#include <memory>
#include <string>
#include <utility>
#include "utils/hash_map.h"
#include "abstract/abstract_function.h"
#include "utils/flags.h"
#include "utils/utils.h"
#include "utils/anf_utils.h"
#include "utils/symbolic.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
using SymbolicKeyConversionMap =
mindspore::HashMap<SymbolicKeyInstancePtr, int64_t, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>;
namespace {
bool IsAbstractEnvType(const abstract::AbstractBasePtr &abs) {
if (abs != nullptr && abs->isa<abstract::AbstractScalar>() && abs->GetTypeTrack()->isa<EnvType>()) {
return true;
}
return false;
}
abstract::AbstractBasePtr TransformAbstractRecursively(const abstract::AbstractBasePtr &orig_abs,
const abstract::AbstractBasePtr target_abs) {
if (orig_abs == nullptr) {
return nullptr;
}
if (IsAbstractEnvType(orig_abs)) {
return target_abs;
}
if (!orig_abs->isa<abstract::AbstractSequence>()) {
return nullptr;
}
const auto &abs_seq = orig_abs->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(abs_seq);
const auto &elements = abs_seq->elements();
abstract::AbstractBasePtrList new_elements;
bool transformed = false;
for (auto elem : elements) {
if (elem->isa<abstract::AbstractSequence>()) {
auto inner_abs_seq = elem->cast<abstract::AbstractSequencePtr>();
auto transformed_abs = TransformAbstractRecursively(inner_abs_seq, target_abs);
if (transformed_abs != nullptr) {
new_elements.push_back(transformed_abs);
transformed = true;
} else {
new_elements.push_back(elem);
}
} else if (IsAbstractEnvType(elem)) {
new_elements.push_back(target_abs);
transformed = true;
} else {
new_elements.push_back(elem);
}
}
if (transformed) {
abstract::AbstractBasePtr new_abs;
if (abs_seq->isa<abstract::AbstractTuple>()) {
new_abs = std::make_shared<abstract::AbstractTuple>(new_elements);
} else if (abs_seq->isa<abstract::AbstractList>()) {
new_abs = std::make_shared<abstract::AbstractList>(new_elements);
} else {
MS_LOG(EXCEPTION) << "abs_seq is not AbstractTuple or AbstractList, but: " << abs_seq->ToString();
}
return new_abs;
}
// No transformation.
return nullptr;
}
void TransformNodeAbstractIfEnvType(const AnfNodePtr &node, const abstract::AbstractBasePtr &target_abs) {
auto transformed_abs = TransformAbstractRecursively(node->abstract(), target_abs);
if (transformed_abs != nullptr) {
node->set_abstract(transformed_abs);
}
}
TypeId GetValueType(const CNodePtr &cnode) {
// (EnvironSet/EnvironGet, environ, key, value/default)
if (cnode->inputs().size() != 4) {
MS_LOG(EXCEPTION) << "EnvrionSet/EnvironGet cnode should have 4 inputs, but: " << cnode->DebugString();
}
const auto &value_abstract = cnode->input(3)->abstract();
if (value_abstract == nullptr) {
MS_LOG(EXCEPTION) << "4th input of EnvironSet/EnvironGet cnode should abstract, but not set, node: "
<< cnode->DebugString();
}
if (IsAbstractEnvType(value_abstract)) {
return kObjectTypeEnvType;
} else {
return kObjectTypeTensorType;
}
}
AnfNodePtr GetTransformedKeyNode(const AnfNodePtr &old_key_node, SymbolicKeyConversionMap *para_symbol_map) {
const auto &symbolic_key_inst = GetValueNode<SymbolicKeyInstancePtr>(old_key_node);
uint64_t transformed_key = 0;
auto &symbolic_key_map = *para_symbol_map;
auto iter = symbolic_key_map.find(symbolic_key_inst);
if (iter != symbolic_key_map.end()) {
transformed_key = iter->second;
} else {
static uint64_t key_counter = 0;
transformed_key = ++key_counter;
symbolic_key_map.emplace(std::make_pair(symbolic_key_inst, transformed_key));
}
auto tensor_key = std::make_shared<mindspore::tensor::Tensor>(transformed_key);
auto transformed_key_node = NewValueNode(tensor_key);
transformed_key_node->set_abstract(tensor_key->ToAbstract());
return transformed_key_node;
}
} // namespace
bool EnvironConversion(const pipeline::ResourcePtr &resource) {
SymbolicKeyConversionMap symbolic_key_map;
static AbstractBasePtr scalar_abs = std::make_shared<abstract::AbstractScalar>(kAnyValue, kInt64);
static AbstractBasePtr tensor_abs = std::make_shared<abstract::AbstractTensor>(scalar_abs);
static std::string attr_name = "value_type";
const int kPrimitiveOffset = 0;
const int kSymbolicKeyOffset = 2;
auto mng = resource->manager();
const auto &all_nodes = mng->all_nodes();
auto txn = mng->Transact();
for (const auto &node : all_nodes) {
TransformNodeAbstractIfEnvType(node, tensor_abs);
if (!IsPrimitiveCNode(node, prim::kPrimEnvironSet) && !IsPrimitiveCNode(node, prim::kPrimEnvironGet)) {
continue;
}
const auto &cnode = node->cast<CNodePtr>();
// Prim
AnfNodePtr transformed_prim_node;
const auto &type_id = GetValueType(cnode);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(kPrimitiveOffset));
MS_EXCEPTION_IF_NULL(prim);
const auto &old_attr = prim->GetAttr(attr_name);
if (old_attr != nullptr) {
const auto old_attr_value = GetValue<int>(old_attr);
if (old_attr_value != type_id) {
if (IsPrimitiveCNode(node, prim::kPrimEnvironSet)) {
prim = std::make_shared<Primitive>(prim::kEnvironSet);
} else {
prim = std::make_shared<Primitive>(prim::kEnvironGet);
}
MS_EXCEPTION_IF_NULL(prim);
prim->set_attr(attr_name, MakeValue(static_cast<int>(type_id)));
transformed_prim_node = NewValueNode(prim);
txn.SetEdge(node, kPrimitiveOffset, transformed_prim_node);
}
} else {
prim->set_attr(attr_name, MakeValue(static_cast<int>(type_id)));
}
// Abstract of Environ & Value will be set by previous TransformNodeAbstract function.
// Key
if (!IsValueNode<SymbolicKeyInstance>(cnode->input(kSymbolicKeyOffset))) {
MS_LOG(EXCEPTION) << "should be SymbolicKey, but: " << cnode->input(kSymbolicKeyOffset)->ToString();
}
const auto &transformed_key_node = GetTransformedKeyNode(cnode->input(kSymbolicKeyOffset), &symbolic_key_map);
txn.SetEdge(node, kSymbolicKeyOffset, transformed_key_node);
}
txn.Commit();
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,6 +1,4 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -16,24 +14,16 @@
* limitations under the License.
*/
#include "utils/symbolic.h"
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_
#include <memory>
#include "pipeline/jit/resource.h"
namespace mindspore {
std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &objPtr) {
out << "()";
return out;
}
bool EnvInstance::operator==(const EnvInstance &other) const { return true; }
bool EnvInstance::operator==(const Value &other) const {
if (other.isa<EnvInstance>()) {
auto other_env_inst = static_cast<const EnvInstance *>(&other);
return *this == *other_env_inst;
}
return false;
}
std::shared_ptr<EnvInstance> newenv = std::make_shared<EnvInstance>();
/* namespace to support opt */
namespace opt {
bool EnvironConversion(const pipeline::ResourcePtr &resource);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_

View File

@ -19,7 +19,7 @@
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/cast_eliminate.h"
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/env_item_eliminate.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
@ -119,26 +119,28 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
real_op_eliminate_ = MakeSubstitution(std::make_shared<RealOpEliminate>(), "real_op_eliminate", prim::kPrimRealInner);
// Env Item Eliminate
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
env_get_item_add_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemAddEliminater>(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem);
env_get_set_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetSetItemEliminater>(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem);
env_get_item_depend_swap_ =
MakeSubstitution(std::make_shared<EnvGetItemDependSwap>(), "env_get_item_depend_swap", prim::kPrimEnvGetItem);
// Environ Item Eliminate
environ_get_eliminate_ =
MakeSubstitution(std::make_shared<EnvironGetEliminater>(), "environ_get_eliminate", prim::kPrimEnvironGet);
environ_get_add_eliminate_ =
MakeSubstitution(std::make_shared<EnvironGetAddEliminater>(), "environ_get_add_eliminate_", prim::kPrimEnvironGet);
environ_get_set_eliminate_ =
MakeSubstitution(std::make_shared<EnvironGetSetEliminater>(), "environ_get_set_eliminate", prim::kPrimEnvironGet);
environ_get_depend_swap_ =
MakeSubstitution(std::make_shared<EnvironGetDependSwap>(), "environ_get_depend_swap", prim::kPrimEnvironGet);
environ_add_const_eliminate_ = MakeSubstitution(std::make_shared<EnvironAddConstEliminater>(),
"environ_add_const_eliminate_", prim::kPrimEnvironAdd);
incorporate_env_getitem_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_environ_get_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvironGet>(true), "incorporate_environ_get", prim::kPrimEnvironGet);
incorporate_environ_get_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitch>(),
"incorporate_environ_get_switch", prim::kPrimEnvironGet);
incorporate_environ_get_ =
MakeSubstitution(std::make_shared<IncorporateEnvironGet>(), "incorporate_environ_get", prim::kPrimEnvironGet);
incorporate_env_getitem_switch_layer_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitchLayer>(), "incorporate_env_getitem_switch_layer",
prim::kPrimEnvGetItem);
incorporate_environ_get_switch_layer_ =
MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitchLayer>(), "incorporate_environ_get_switch_layer",
prim::kPrimEnvironGet);
// Ref eliminate
make_ref_eliminate_ =
@ -157,8 +159,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
"float_tuple_getitem_switch", prim::kPrimTupleGetItem);
float_env_getitem_switch_ =
MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
float_environ_get_switch_ =
MakeSubstitution(std::make_shared<FloatEnvironGetSwitch>(), "float_environ_get_switch", prim::kPrimEnvironGet);
exchange_switch_depend_value_ =
MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);

View File

@ -65,14 +65,15 @@ class OptimizeIRPassLib {
SubstitutionPtr real_op_eliminate_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr env_get_item_add_eliminate_;
SubstitutionPtr env_get_set_item_eliminate_;
SubstitutionPtr env_get_item_depend_swap_;
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_;
SubstitutionPtr incorporate_env_getitem_switch_layer_;
SubstitutionPtr environ_get_eliminate_;
SubstitutionPtr environ_get_add_eliminate_;
SubstitutionPtr environ_get_set_eliminate_;
SubstitutionPtr environ_get_depend_swap_;
SubstitutionPtr environ_add_const_eliminate_;
SubstitutionPtr incorporate_environ_get_;
SubstitutionPtr incorporate_environ_get_bypass_recursive_;
SubstitutionPtr incorporate_environ_get_switch_;
SubstitutionPtr incorporate_environ_get_switch_layer_;
// Ref eliminate
SubstitutionPtr make_ref_eliminate_;
@ -84,7 +85,7 @@ class OptimizeIRPassLib {
// Branch culling
SubstitutionPtr switch_simplify_;
SubstitutionPtr float_tuple_getitem_switch_;
SubstitutionPtr float_env_getitem_switch_;
SubstitutionPtr float_environ_get_switch_;
SubstitutionPtr exchange_switch_depend_value_;
SubstitutionPtr switch_partial_eliminater_;

View File

@ -58,8 +58,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}},
{prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}},
{prim::kPrimEnvSetItem, {1}},
{prim::kPrimEnvironGet, {1}},
{prim::kPrimEnvironSet, {1}},
{prim::kPrimReduceSum, {2}},
{prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}},
@ -83,7 +83,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
{prim::kPrimEnvironGet, {1}}, {prim::kPrimEnvironSet, {1}},
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},

View File

@ -66,16 +66,16 @@ class FloatTupleGetItemSwitch : public OptimizerCaller {
}
};
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
class FloatEnvGetItemSwitch : public OptimizerCaller {
// {prim::kPrimEnvironGet, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimSwitch, X1, {prim::kPrimEnvironGet, X2, X4, X5}, {prim::kPrimEnvironGet, X3, X4, X5}}
class FloatEnvironGetSwitch : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
MATCH_REPLACE(node,
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)));
PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvironGet, true_br, x, x2),
PPrimitive(prim::kPrimEnvironGet, false_br, x, x2)));
return nullptr;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_
#include <algorithm>
#include <memory>
@ -37,10 +37,15 @@ namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
class EnvGetitemTransform {
constexpr int kEnvironGetSetInputSize = 4;
constexpr int kEnvironOffset = 1;
constexpr int kSymbolicKeyOffset = 2;
constexpr int kValueOffset = 3;
class EnvironGetTransform {
public:
EnvGetitemTransform() : cache_() {}
~EnvGetitemTransform() = default;
EnvironGetTransform() : cache_() {}
~EnvironGetTransform() = default;
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
if (cache_.find(fg) == cache_.end()) {
@ -56,22 +61,22 @@ class EnvGetitemTransform {
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
auto env = new_fg->output();
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4) {
auto environ_node = new_fg->output();
while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) {
// {prim::kPrimEnvironSet, environ, symbolickey, value}
auto &inputs = environ_node->cast<CNodePtr>()->inputs();
if (inputs.size() != kEnvironGetSetInputSize) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
if (!IsValueNode<SymbolicKeyInstance>(inputs[kSymbolicKeyOffset])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
auto value = inputs[3];
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
environ_node = inputs[kEnvironOffset];
auto value = inputs[kValueOffset];
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[kSymbolicKeyOffset]);
if (*key2 == *key) {
new_fg->set_output(value);
cache[hash_key] = new_fg;
@ -79,7 +84,8 @@ class EnvGetitemTransform {
return new_fg;
}
}
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
new_fg->set_output(
new_fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, NewValueNode(key), default_node}));
cache[hash_key] = new_fg;
}
@ -92,10 +98,10 @@ class EnvGetitemTransform {
cache_;
};
class EnvGetitemTransformACrossGraph {
class EnvironGetTransformACrossGraph {
public:
EnvGetitemTransformACrossGraph() : cache_() {}
~EnvGetitemTransformACrossGraph() = default;
EnvironGetTransformACrossGraph() : cache_() {}
~EnvironGetTransformACrossGraph() = default;
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
if (cache_.find(fg) == cache_.end()) {
@ -120,29 +126,30 @@ class EnvGetitemTransformACrossGraph {
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str()));
new_fg_outer->set_output(NewValueNode(new_fg));
auto env = new_fg->output();
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4) {
auto environ_node = new_fg->output();
while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) {
// {prim::kPrimEnvironSet, environ, symbolickey, value}
auto &inputs = environ_node->cast<CNodePtr>()->inputs();
if (inputs.size() != kEnvironGetSetInputSize) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
if (!IsValueNode<SymbolicKeyInstance>(inputs[kSymbolicKeyOffset])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
auto value = inputs[3];
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
environ_node = inputs[kEnvironOffset];
auto value = inputs[kValueOffset];
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[kSymbolicKeyOffset]);
if (*key2 == *key) {
new_fg->set_output(value);
cache[hash_key] = new_fg_outer;
return new_fg_outer;
}
}
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
new_fg->set_output(
new_fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, NewValueNode(key), default_node}));
cache[hash_key] = new_fg_outer;
}
@ -156,48 +163,48 @@ class EnvGetitemTransformACrossGraph {
};
} // namespace internal
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
class EnvGetItemEliminater : public AnfVisitor {
// {prim::kPrimEnvironGet, C1, C2, Y} -> Y
class EnvironGetEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode c1, c2, y;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
(IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node))));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironGet, c1, c2, y), y,
(IsNewEnvironNode(c1.GetNode(node)) && IsVNode(c2.GetNode(node))));
return nullptr;
}
};
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
class EnvGetItemAddEliminater : public AnfVisitor {
// {prim::kPrimEnvironGet, {prim::kPrimEnvironAdd, X, Y}, C, Z} ->
// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvironGet, X, C, Z}, {prim::kPrimEnvironGet, Y, C, Z}}
class EnvironGetAddEliminater : public AnfVisitor {
public:
EnvGetItemAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
~EnvGetItemAddEliminater() override = default;
EnvironGetAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
~EnvironGetAddEliminater() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
auto IsAddCNode = [](const AnfNodePtr &node) -> bool {
return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast<CNodePtr>()->size() == 3;
return IsPrimitiveCNode(node, prim::kPrimEnvironAdd) && node->cast<CNodePtr>()->size() == 3;
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimEnvironGet, {IsAddCNode, IsVNode, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C, Z}
// {prim::kPrimEnvironGet, {...}, C, Z}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto c = cnode->input(2);
auto z = cnode->input(3);
// {prim::kPrimEnvAdd, X, Y}
// {prim::kPrimEnvironAdd, X, Y}
auto x = inp1->input(1);
auto y = inp1->input(2);
auto fg = node->func_graph();
auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z});
auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z});
auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), x, c, z});
auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), y, c, z});
return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz});
}
@ -209,17 +216,17 @@ class EnvGetItemAddEliminater : public AnfVisitor {
ValuePtr PrimHyperAdd_;
};
// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
class EnvGetSetItemEliminater : public AnfVisitor {
// {prim::kPrimEnvironGet, {prim::kPrimEnvironSet, X, C1, Y}, C2, Z}
class EnvironGetSetEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
auto IsSetCNode = [](const AnfNodePtr &node) -> bool {
if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) {
if (!IsPrimitiveCNode(node, prim::kPrimEnvironSet)) {
return false;
}
// {prim::kPrimEnvSetItem, X, C1, Y}
// {prim::kPrimEnvironSet, X, C1, Y}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() != 4) {
return false;
@ -227,21 +234,21 @@ class EnvGetSetItemEliminater : public AnfVisitor {
return IsValueNode<SymbolicKeyInstance>(inputs[2]);
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
AnfVisitor::Match(prim::kPrimEnvironGet, {IsSetCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C2, Z}
// {prim::kPrimEnvironGet, {...}, C2, Z}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto key2 = cnode->input(2);
auto c2 = GetValueNode<SymbolicKeyInstancePtr>(key2);
auto default_v = cnode->input(3);
// {prim::kPrimEnvSetItem, X, C1, Y}
auto env = inp1->input(1);
// {prim::kPrimEnvironSet, X, C1, Y}
AnfNodePtr environ_node = inp1->input(1);
auto c1 = GetValueNode<SymbolicKeyInstancePtr>(inp1->input(2));
auto last_set = inp1->input(3);
@ -249,27 +256,27 @@ class EnvGetSetItemEliminater : public AnfVisitor {
return last_set;
}
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4) {
while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) {
// {prim::kPrimEnvironSet, environ, symbolickey, value}
auto &inputs = environ_node->cast<CNodePtr>()->inputs();
if (inputs.size() != internal::kEnvironGetSetInputSize) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
if (!IsValueNode<SymbolicKeyInstance>(inputs[internal::kSymbolicKeyOffset])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
last_set = inputs[3];
auto symbolic_c1 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
environ_node = inputs[internal::kEnvironOffset];
last_set = inputs[internal::kValueOffset];
auto symbolic_c1 = GetValueNode<SymbolicKeyInstancePtr>(inputs[internal::kSymbolicKeyOffset]);
if (*symbolic_c1 == *c2) {
return last_set;
}
}
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v});
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, key2, default_v});
}
void Visit(const AnfNodePtr &) override { is_match_ = true; }
@ -278,9 +285,9 @@ class EnvGetSetItemEliminater : public AnfVisitor {
bool is_match_{false};
};
// {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} ->
// {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2}
class EnvGetItemDependSwap : public OptimizerCaller {
// {prim::kPrimEnvironGet, {prim::kPrimDepend, X1, X2}, item, dflt} ->
// {prim::kPrimDepend, {prim::kPrimEnvironGet, X1, item, dflt}, X2}
class EnvironGetDependSwap : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
@ -290,18 +297,30 @@ class EnvGetItemDependSwap : public OptimizerCaller {
ScopeGuard scope_guard(scope);
PatternNode x1, x2, item, dflt;
MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt),
PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvGetItem, x1, item, dflt), x2));
MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt),
PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvironGet, x1, item, dflt), x2));
return nullptr;
}
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor {
// {prim::kPrimEnvironAdd, C1, X} -> X
// {prim::kPrimEnvironAdd, X, C1} -> X
class EnvironAddConstEliminater : public AnfVisitor {
public:
explicit IncorporateEnvGetitem(bool bypass_recursive = false)
: env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
~IncorporateEnvGetitem() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode c1, x;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironAdd, c1, x), x, (IsNewEnvironNode(c1.GetNode(node))));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironAdd, x, c1), x, (IsNewEnvironNode(c1.GetNode(node))));
return nullptr;
}
};
// {prim::kPrimEnvironGet, {G, Xs}, C, Y}
class IncorporateEnvironGet : public AnfVisitor {
public:
explicit IncorporateEnvironGet(bool bypass_recursive = false)
: environ_get_transform_(), bypass_recursive_(bypass_recursive) {}
~IncorporateEnvironGet() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
@ -316,13 +335,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
}
return IsValueNode<FuncGraph>(cnode->input(0));
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
AnfVisitor::Match(prim::kPrimEnvironGet, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C, Y}
// {prim::kPrimEnvironGet, {...}, C, Y}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
@ -331,9 +350,9 @@ class IncorporateEnvGetitem : public AnfVisitor {
// {G, Xs}
auto inputs = inp1->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
auto new_fg = env_get_item_transform_(fg, key, default_v);
auto new_fg = environ_get_transform_(fg, key, default_v);
if (fg->recursive() && bypass_recursive_) {
MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
MS_LOG(DEBUG) << "Bypass EnvironGet transform for recursive fg=" << fg->ToString();
return nullptr;
}
if (new_fg == nullptr) {
@ -350,15 +369,15 @@ class IncorporateEnvGetitem : public AnfVisitor {
private:
bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_;
internal::EnvironGetTransform environ_get_transform_;
bool bypass_recursive_;
};
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
class IncorporateEnvGetitemSwitch : public AnfVisitor {
// {prim::kPrimEnvironGet, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
class IncorporateEnvironGetSwitch : public AnfVisitor {
public:
IncorporateEnvGetitemSwitch() : env_get_item_transform_() {}
~IncorporateEnvGetitemSwitch() override = default;
IncorporateEnvironGetSwitch() : environ_get_transform_() {}
~IncorporateEnvironGetSwitch() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
@ -374,12 +393,12 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch);
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
AnfVisitor::Match(prim::kPrimEnvironGet, {IsSwNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C, Y}
// {prim::kPrimEnvironGet, {...}, C, Y}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
@ -398,8 +417,8 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
auto x = sw->input(1);
auto g1 = GetValueNode<FuncGraphPtr>(sw->input(2));
auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
auto new_g1 = env_get_item_transform_(g1, key, default_v);
auto new_g2 = env_get_item_transform_(g2, key, default_v);
auto new_g1 = environ_get_transform_(g1, key, default_v);
auto new_g2 = environ_get_transform_(g2, key, default_v);
if (new_g1 == nullptr || new_g2 == nullptr) {
return nullptr;
}
@ -416,14 +435,14 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
private:
bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_;
internal::EnvironGetTransform environ_get_transform_;
};
// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
// {prim::kPrimEnvironGet, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
class IncorporateEnvironGetSwitchLayer : public AnfVisitor {
public:
IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {}
~IncorporateEnvGetitemSwitchLayer() override = default;
IncorporateEnvironGetSwitchLayer() : environ_get_transform_() {}
~IncorporateEnvironGetSwitchLayer() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
@ -431,11 +450,11 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
return nullptr;
}
is_match_ = false;
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
AnfVisitor::Match(prim::kPrimEnvironGet, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C, Y}
// {prim::kPrimEnvironGet, {...}, C, Y}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
@ -463,7 +482,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
std::vector<FuncGraphPtr> graphs{};
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
auto &graphs_inputs = graphs_cnode->inputs();
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
const int kMinInputSize = 2;
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= kMinInputSize &&
IsValueNode<FuncGraph>(graphs_inputs[1])) {
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
@ -475,7 +495,7 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
auto fg = node->func_graph();
std::vector<AnfNodePtr> layers;
for (auto &graph : graphs) {
auto fg_transform = env_get_item_transform_(graph, key, default_v);
auto fg_transform = environ_get_transform_(graph, key, default_v);
if (fg_transform == nullptr) {
return nullptr;
}
@ -493,9 +513,9 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
private:
bool is_match_{false};
internal::EnvGetitemTransformACrossGraph env_get_item_transform_;
internal::EnvironGetTransformACrossGraph environ_get_transform_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_

View File

@ -648,15 +648,15 @@ class IncorporateGetitemSwitch : public AnfVisitor {
MS_EXCEPTION_IF_NULL(switch_call);
const auto &switch_call_cnode = switch_call->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_call_cnode);
// If exist env_getitem/env_setitem in this funcgraph or
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
// If exist EnvironGet/EnvironSet in this funcgraph or
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any EnvironGet or EnvironSet;
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index;
auto switch_call_users_counter = MultipleUseOfSwitch(switch_call, fg, &tp_cnodes_and_index);
bool multiple_use = (tp_cnodes_and_index.size() > 1);
if (g1_output_is_shrinkable && g2_output_is_shrinkable && multiple_use &&
(tp_cnodes_and_index.size() == switch_call_users_counter)) {
if (!internal::HasMoreJ(optimizer) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
!ExistEnvNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) {
if (!internal::HasMoreJ(optimizer) && !ExistEnvironNode(fg) && !ExistEnvironNodeInTupleItem(g1_) &&
!ExistEnvironNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) {
MS_LOG(DEBUG) << "No more j, will shrink. Node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
const auto g1_output_size = internal::GetOutputSize(g1_->output());
@ -834,15 +834,15 @@ class IncorporateGetitemSwitch : public AnfVisitor {
return node_users.size();
}
static bool inline ExistEnvNode(const FuncGraphPtr &fg) {
static bool inline ExistEnvironNode(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg);
auto &nodes = fg->value_nodes();
return std::any_of(nodes.begin(), nodes.end(), [](const auto &node) {
return IsPrimitive(node.first, prim::kPrimEnvSetItem) || IsPrimitive(node.first, prim::kPrimEnvGetItem);
return IsPrimitive(node.first, prim::kPrimEnvironSet) || IsPrimitive(node.first, prim::kPrimEnvironGet);
});
}
static bool inline ExistEnvNodeInTupleItem(const FuncGraphPtr &fg) {
static bool inline ExistEnvironNodeInTupleItem(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg);
const auto &output = fg->output();
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
@ -852,7 +852,7 @@ class IncorporateGetitemSwitch : public AnfVisitor {
const auto &inputs = cnode->inputs();
return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &input) {
auto sub_fg = GetValueNode<FuncGraphPtr>(input);
if (sub_fg != nullptr && ExistEnvNode(sub_fg)) {
if (sub_fg != nullptr && ExistEnvironNode(sub_fg)) {
return true;
}
return false;

View File

@ -167,12 +167,12 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state) {
return input_monad;
}
bool ExistEnvGetItem(FuncGraphManagerPtr manager) {
bool ExistEnvironGet(FuncGraphManagerPtr manager) {
const FuncGraphSet &fgs = manager->func_graphs();
for (auto &fg : fgs) {
auto &nodes = fg->value_nodes();
bool exist = std::any_of(nodes.begin(), nodes.end(),
[](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvGetItem); });
[](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvironGet); });
if (exist) {
return true;
}
@ -181,10 +181,10 @@ bool ExistEnvGetItem(FuncGraphManagerPtr manager) {
}
// Convert:
// cnode1 = env_setitem(EnvInstance, para1, attach1)
// cnode2 = env_setitem(cnode1, para2, attach2)
// cnode1 = EnvironSet(EnvCreate(), para1, attach1)
// cnode2 = EnvironSet(cnode1, para2, attach2)
// ...
// cnode_n = env_setitem(cnode_n-1, para_n-1, attach_n-1)
// cnode_n = EnvironSet(cnode_n-1, para_n-1, attach_n-1)
// maketuple = maketuple(cnode_n, ...)
// updatestate = updatestate(umonad, maketuple)
// To:
@ -194,26 +194,26 @@ AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_st
std::vector<AnfNodePtr> env_nodes;
std::vector<AnfNodePtr> new_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
size_t input_size = make_tuple->inputs().size();
bool has_env_setitem = false;
bool has_environ_set = false;
for (size_t i = 1; i < input_size; i++) {
auto node = make_tuple->input(i);
if (IsPrimitiveCNode(node, prim::kPrimEnvSetItem) && OnlyUsedByOneNode(node, make_tuple)) {
if (IsPrimitiveCNode(node, prim::kPrimEnvironSet) && OnlyUsedByOneNode(node, make_tuple)) {
env_nodes.emplace_back(node);
has_env_setitem = true;
has_environ_set = true;
} else if (node->isa<CNode>() && !IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
new_maketuple_inputs.emplace_back(node);
}
}
if (!has_env_setitem) {
if (!has_environ_set) {
return nullptr;
}
// Check env_setitem in MakeTuple
// Check EnvironSet in MakeTuple
auto mgr = GetManager(update_state);
if (mgr == nullptr) {
return nullptr;
}
// If exist env_getitem, don't eliminate env_setitem.
if (ExistEnvGetItem(mgr)) {
// If exist EnvironGet, don't eliminate EnvironSet.
if (ExistEnvironGet(mgr)) {
return nullptr;
}
const size_t first_index = 1;
@ -228,7 +228,7 @@ AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_st
auto env_cnode = env->cast<CNodePtr>();
auto env_input = env_cnode->input(first_index);
auto attach = env_cnode->input(attach_index);
if (IsPrimitiveCNode(env_input, prim::kPrimEnvSetItem) && OnlyUsedByOneNode(env_input, env_cnode)) {
if (IsPrimitiveCNode(env_input, prim::kPrimEnvironSet) && OnlyUsedByOneNode(env_input, env_cnode)) {
env_nodes.emplace_back(env_input);
new_maketuple_inputs.insert(new_maketuple_inputs.begin() + no_env_node_size, attach);
}

View File

@ -464,12 +464,11 @@ constexpr char TUPLE_DIV[] = "tuple_div";
constexpr char TUPLE_TO_ARRAY[] = "tuple_to_array";
constexpr char VIRTUALLOSS[] = "VirtualLoss";
constexpr char RETURN[] = "Return";
constexpr char ENV_GETITEM[] = "env_getitem";
constexpr char ENVIRONGET[] = "EnvironGet";
constexpr char IDENTITY[] = "identity";
constexpr char PARTIAL[] = "partial";
constexpr char ENVSETITEM[] = "env_setitem";
constexpr char ENVGETITEM[] = "env_getitem";
constexpr char ENVADD[] = "env_add";
constexpr char ENVIRONSET[] = "EnvironSet";
constexpr char ENVIRONADD[] = "EnvironAdd";
constexpr char MAKEREFKEY[] = "MakeRefKey";
constexpr char MAKEREF[] = "make_ref";
constexpr char GETREFKEY[] = "get_ref_key";

View File

@ -2862,7 +2862,7 @@ static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
if (!node->isa<CNode>()) {
continue;
}
if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
if (!IsPrimitiveCNode(node, prim::kPrimEnvironGet)) {
return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
} else {
return node;

View File

@ -45,7 +45,6 @@
namespace py = pybind11;
using EnvInstance = mindspore::EnvInstance;
using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy;
using Pipeline = mindspore::pipeline::Pipeline;
using PrimitivePy = mindspore::PrimitivePy;
@ -118,8 +117,6 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.")
.def("generate_arguments_key", &GraphExecutorPy::GenerateArgumentsKey, "Generate unique key of argument.");
(void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
(void)m.def("real_run_op", &mindspore::pynative::RealRunOp, "Run op pynatively.");
(void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
(void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");

View File

@ -502,7 +502,6 @@ std::vector<DataConverterPtr> GetDataConverters() {
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
[](const py::object &obj) -> ValuePtr {
auto res =

View File

@ -42,6 +42,7 @@
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "frontend/optimizer/slice_activation_in_recompute.h"
#include "frontend/optimizer/environ_conversion.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/pynative/pynative_execute.h"
@ -198,12 +199,12 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
opt::OptPassConfig env_eliminate = opt::OptPassConfig({
opt::OptPassConfig environ_eliminate = opt::OptPassConfig({
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_getitem_set_,
});
(void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate));
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate));
}
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map);
@ -264,10 +265,11 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.tuple_list_get_item_depend_reorder_,
irpass.tuple_list_convert_item_index_to_positive_,
irpass.env_get_item_eliminate_,
irpass.env_get_item_add_eliminate_,
irpass.env_get_set_item_eliminate_,
irpass.env_get_item_depend_swap_,
irpass.environ_get_eliminate_,
irpass.environ_get_add_eliminate_,
irpass.environ_get_set_eliminate_,
irpass.environ_get_depend_swap_,
irpass.environ_add_const_eliminate_,
irpass.cast_eliminate_,
irpass.reshape_eliminate_,
@ -301,16 +303,16 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.specialize_transform_,
irpass.merge_addn_,
irpass.float_tuple_getitem_switch_,
irpass.float_env_getitem_switch_,
irpass.float_environ_get_switch_,
irpass.inline_,
irpass.updatestate_useless_node_eliminater_,
irpass.tuple_list_get_item_eliminator_,
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_,
irpass.env_get_item_eliminate_,
irpass.incorporate_environ_get_bypass_recursive_,
irpass.incorporate_environ_get_switch_,
irpass.environ_get_eliminate_,
irpass.depend_value_elim_,
irpass.all_reduce_const_elim_,
},
@ -434,13 +436,14 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.stopgrad_eliminater_,
irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_,
irpass.env_get_item_eliminate_,
irpass.env_get_item_add_eliminate_,
irpass.env_get_set_item_eliminate_,
irpass.env_get_item_depend_swap_,
irpass.incorporate_env_getitem_switch_layer_,
irpass.incorporate_environ_get_,
irpass.incorporate_environ_get_switch_,
irpass.environ_get_eliminate_,
irpass.environ_get_add_eliminate_,
irpass.environ_get_set_eliminate_,
irpass.environ_get_depend_swap_,
irpass.environ_add_const_eliminate_,
irpass.incorporate_environ_get_switch_layer_,
irpass.value_based_eliminate_,
irpass.virtual_accu_grad_,
irpass.virtual_assign_add_,
@ -732,6 +735,15 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
return true;
}
bool EnvironConversionPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
if (enable_closure) {
opt::EnvironConversion(res);
}
return true;
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_before_recompute", OptBeforeRecomputeGroup},
{"opt_a", OptPassAGroup},
@ -744,6 +756,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"add_cache_embedding", AddCacheEmbeddingPass},
{"add_recomputation", AddRecomputationPass},
{"cse_after_recomputation", OptAfterRecomputeGroup},
{"environ_conv", EnvironConversionPass},
{"slice_recompute_activation", SliceRecomputeActivationPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},

View File

@ -123,12 +123,14 @@ AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -43,7 +43,15 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr
return args_spec_list[0];
}
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: None.
CheckArgsSize(primitive->name(), args_spec_list, 0);
static AbstractBasePtr abs_env = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
return abs_env;
}
AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
// args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
@ -53,7 +61,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
TypePtr type = key->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kObjectTypeSymbolicKeyType) {
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
MS_LOG(EXCEPTION) << "EnvironGet evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
}
auto context = MsContext::GetInstance();
@ -76,7 +84,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
return expected;
}
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
CheckArgsSize(primitive->name(), args_spec_list, 3);
@ -86,7 +94,7 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(key_value_ptr);
auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
if (key_value_track == nullptr) {
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
MS_LOG(EXCEPTION) << "EnvironSet evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
<< key_value_ptr->ToString();
}
auto expected = key_value_track->abstract();
@ -94,8 +102,8 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
CheckArgsSize(primitive->name(), args_spec_list, 2);
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());

View File

@ -200,9 +200,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimLoad, R{InferImplLoad, nullptr, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, R{nullptr, nullptr, true}},
{prim::kPrimEnvGetItem, R{InferImplEnvGetItem, nullptr, true}},
{prim::kPrimEnvSetItem, R{InferImplEnvSetItem, nullptr, true}},
{prim::kPrimEnvAdd, R{InferImplEnvAdd, nullptr, true}},
{prim::kPrimEnvironCreate, R{InferImplEnvironCreate, nullptr, true}},
{prim::kPrimEnvironGet, R{InferImplEnvironGet, nullptr, true}},
{prim::kPrimEnvironSet, R{InferImplEnvironSet, nullptr, true}},
{prim::kPrimEnvironAdd, R{InferImplEnvironAdd, nullptr, true}},
{prim::kPrimMakeRefKey, R{InferImplMakeRefKey, nullptr, true}},
{prim::kPrimMakeRef, R{InferImplMakeRef, nullptr, true}},
{prim::kPrimGetRefKey, R{InferImplGetRefKey, nullptr, true}},

View File

@ -120,6 +120,13 @@ constexpr auto kSparseApplyAdadelta = "SparseApplyAdadelta";
constexpr auto kRoll = "Roll";
constexpr auto kTanh = "Tanh";
// Others
constexpr auto kEnvironCreate = "EnvironCreate";
constexpr auto kEnvironSet = "EnvironSet";
constexpr auto kEnvironGet = "EnvironGet";
constexpr auto kEnvironAdd = "EnvironAdd";
//
// Here list all primitives used in backend or some special primitives used by core.
// GetNext
inline const PrimitivePtr kPrimGetNext = std::make_shared<Primitive>(kGetNext);
@ -708,9 +715,10 @@ inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_ap
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
// Other miscellaneous
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimEnvironCreate = std::make_shared<Primitive>(kEnvironCreate);
inline const PrimitivePtr kPrimEnvironSet = std::make_shared<Primitive>(kEnvironSet);
inline const PrimitivePtr kPrimEnvironGet = std::make_shared<Primitive>(kEnvironGet);
inline const PrimitivePtr kPrimEnvironAdd = std::make_shared<Primitive>(kEnvironAdd);
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");

View File

@ -26,7 +26,9 @@
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
#include "base/core_ops.h"
namespace mindspore {
class SymbolicKeyInstance : public Value {
@ -97,37 +99,15 @@ struct SymbolicKeyInstanceEqual {
}
};
using EnvInstanceContentsMap =
mindspore::HashMap<SymbolicKeyInstancePtr, Any, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>;
static inline AnfNodePtr NewEnviron(const FuncGraphPtr &fg) {
return fg->NewCNode({NewValueNode(prim::kPrimEnvironCreate)});
}
// Environment mapping keys to values.
// Keys are SymbolicKeyInstances, which represent nodes in the graph along
// with inferred properties.
class EnvInstance : public Value {
public:
friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &env);
static inline bool IsNewEnvironNode(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimEnvironCreate); }
EnvInstance() = default;
~EnvInstance() override = default;
MS_DECLARE_PARENT(EnvInstance, Value);
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<EnvInstance>(), std::make_shared<EnvType>());
}
bool operator==(const EnvInstance &other) const;
bool operator==(const Value &other) const override;
EnvInstance(const EnvInstance &v) : Value(v) {}
EnvInstance(EnvInstance &&v) = default;
EnvInstance &operator=(EnvInstance &&src) noexcept { return *this; };
std::size_t hash() const override {
// deterministic characteristic of member variables.
return tid();
}
};
using EnvInstancePtr = std::shared_ptr<EnvInstance>;
extern std::shared_ptr<EnvInstance> newenv;
static inline abstract::AbstractBasePtr MakeEnvironAbstract() {
return std::make_shared<abstract::AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_SYMBOLIC_H_

View File

@ -17,7 +17,7 @@
"""Data type for MindSpore."""
import numpy as np
from .._c_expression import typing, EnvInstance_
from .._c_expression import typing
from .._c_expression.typing import Type
__dtype__ = [
@ -161,7 +161,6 @@ _simple_types = {
np.float16: float16,
np.float32: float32,
np.float64: float64,
EnvInstance_: env_type,
}

View File

@ -20,7 +20,7 @@ from functools import partial
from types import FunctionType
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
@ -29,7 +29,7 @@ from ..operations import _grad_ops
from .. import operations as P
from .. import signature as sig
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
def add_flags(fn=None, **flags):
@ -852,7 +852,7 @@ zip_operation = _ZipOperation('zip_operation')
env_get = MultitypeFuncGraph("env_get")
env_getitem = Primitive('env_getitem')
environ_get = Primitive('EnvironGet')
ref_to_embed = _grad_ops.RefToEmbed()
zeros_like = P.ZerosLike()
@ -860,4 +860,4 @@ zeros_like = P.ZerosLike()
@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))
return environ_get(env, ref_to_embed(parameter), zeros_like(parameter))

View File

@ -254,7 +254,7 @@ def _add_env(x, y):
Returns:
EnvType, equal to x + y.
"""
return F.env_add(x, y)
return F.environ_add(x, y)
@add.register("Tuple", "Tuple")

View File

@ -58,9 +58,6 @@ def _ones_like_sparse_tensor(x):
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
newenv = base.EnvInstance_()
@ones_like_leaf.register("Function")
def _ones_like_func(x):
"""
@ -70,10 +67,10 @@ def _ones_like_func(x):
x (Function): x
Returns:
EnvInstance_, value is newenv.
A instance of EnvType.
"""
# Unused parameters are placeholders.
return newenv
return F.environ_create()
ones_like = base.HyperMap(ones_like_leaf)

View File

@ -37,9 +37,6 @@ def _zeros_like_bool(x):
return False
newenv = base.EnvInstance_()
@zeros_like_leaf.register("Function")
def _zeros_like_func(x):
"""
@ -49,10 +46,10 @@ def _zeros_like_func(x):
x (Function): x
Returns:
EnvInstance_, value is newenv.
A instance of EnvType.
"""
# Unused parameters are placeholders.
return newenv
return F.environ_create()
@zeros_like_leaf.register("Tensor")

View File

@ -456,9 +456,10 @@ zeros_like = P.ZerosLike()
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
env_setitem = Primitive('env_setitem')
env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
environ_create = Primitive('EnvironCreate')
environ_set = Primitive('EnvironSet')
environ_get = Primitive('EnrironGet')
environ_add = Primitive('EnvironAdd')
J = Primitive('J')
SliceGetItem = Primitive("SliceGetItem")
switch = Primitive('Switch')

View File

@ -2047,7 +2047,7 @@ class RefToEmbed(Primitive):
Make a key from Ref.
The Key is a symbolic_key, is a embedding on Parameter, which is used as a key of the variable in env_type,
and get items by operation `env_get_item` with the symbolic_key instance. The `Parameter` is a ref.
and get items by operation `EnvironGet` with the symbolic_key instance. The `Parameter` is a ref.
Inputs:
- **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.

View File

@ -346,19 +346,28 @@ TEST_F(TestOps, EmbedTest) {
ASSERT_EQ(prim->name(), kPrimEmbed->name());
}
TEST_F(TestOps, EnvSetItemTest) {
auto prim = std::make_shared<Primitive>("env_setitem");
ASSERT_EQ(prim->name(), kPrimEnvSetItem->name());
/// Feature: Check primitive name equivalence
/// Description: EnvironSet primitive name equivalence
/// Expectation: Equal
TEST_F(TestOps, EnvironSetTest) {
auto prim = std::make_shared<Primitive>("EnvironSet");
ASSERT_EQ(prim->name(), kPrimEnvironSet->name());
}
TEST_F(TestOps, EnvGetItemTest) {
auto prim = std::make_shared<Primitive>("env_getitem");
ASSERT_EQ(prim->name(), kPrimEnvGetItem->name());
/// Feature: Check primitive name equivalence
/// Description: EnvironGet primitive name equivalence
/// Expectation: Equal
TEST_F(TestOps, EnvironGetTest) {
auto prim = std::make_shared<Primitive>("EnvironGet");
ASSERT_EQ(prim->name(), kPrimEnvironGet->name());
}
TEST_F(TestOps, EnvAddest) {
auto prim = std::make_shared<Primitive>("env_add");
ASSERT_EQ(prim->name(), kPrimEnvAdd->name());
/// Feature: Check primitive name equivalence
/// Description: EnvironAdd primitive name equivalence
/// Expectation: Equal
TEST_F(TestOps, EnvironAddTest) {
auto prim = std::make_shared<Primitive>("EnvironAdd");
ASSERT_EQ(prim->name(), kPrimEnvironAdd->name());
}
// Neural Network

View File

@ -370,81 +370,81 @@ TEST_F(TestPrim, test_partial) {
ASSERT_TRUE(res->ToString() == expected->ToString());
}
// def test_env(x, y):
// return env_setitem(newenv, embed(x), y)
TEST_F(TestPrim, test_env_setitem) {
/// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
/// Description: Check inference result of primitive EnvironSet.
/// Expectation: Equal
TEST_F(TestPrim, test_environ_set) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvironSet, 3);
AbstractBasePtr abstract_env = ToAbstract(newenv);
AbstractBasePtr abstract_environ = MakeEnvironAbstract();
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
args_spec_list = {abstract_env, embed_x, abstract_y};
args_spec_list = {abstract_environ, embed_x, abstract_y};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*res == *exp);
}
// def test_env(x, y, z):
// e = env_setitem(newenv, embed(x), y)
// return env_getitem(e, embed(x), z)
TEST_F(TestPrim, test_env_getitem) {
/// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
/// Description: Check inference result of primitive EnvironGet.
/// Expectation: Equal
TEST_F(TestPrim, test_environ_get) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);
AbstractBasePtr abstract_env = ToAbstract(newenv);
AbstractBasePtr abstract_environ = MakeEnvironAbstract();
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
args_spec_list = {abstract_env, embed_x, abstract_y};
args_spec_list = {abstract_environ, embed_x, abstract_y};
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*res == *exp);
FuncGraphPtr graph_getitem = MakeFuncGraph(prim::kPrimEnvGetItem, 3);
FuncGraphPtr graph_environ_get = MakeFuncGraph(prim::kPrimEnvironGet, 3);
AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
args_spec_list = {res, embed_x, abstract_z};
res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
res = engine_->Run(graph_environ_get, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *abstract_x);
}
// def test_env(x, y, z):
// e1 = env_setitem(newenv, embed(x), y)
// e2 = env_setitem(newenv, embed(x), z)
// return env_add(e1, e2)
TEST_F(TestPrim, test_env_add) {
/// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
/// Description: Check inference result of primitive EnvironAdd.
/// Expectation: Equal
TEST_F(TestPrim, test_environ_add) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);
AbstractBasePtr abstract_env = ToAbstract(newenv);
AbstractBasePtr abstract_environ = MakeEnvironAbstract();
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
args_spec_list = {abstract_env, embed_x, abstract_y};
args_spec_list = {abstract_environ, embed_x, abstract_y};
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*abstract_e1 == *exp);
AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
args_spec_list = {abstract_env, embed_x, abstract_z};
args_spec_list = {abstract_environ, embed_x, abstract_z};
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
ASSERT_TRUE(*abstract_e2 == *exp);
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvironAdd, 2);
args_spec_list = {abstract_e1, abstract_e2};
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();

View File

@ -315,7 +315,7 @@ def test_while_if():
# should compile success without zeros_like_tensor args mismatch, the generated graph files
# should not contain env_getitem or env_setitem.
# should not contain EnvironGet or EnvironSet.
# InsertGradientOf primitive will make func_graph bprop_construct had BackPropAutoMonad flag set,
# so all graph it used will be checked if any side effect it has, so the hyper_map_zeros_like
# will have U as parameter, but the call site zeros_like(fv) don't have U argument.
@ -347,7 +347,7 @@ def test_grad_fv_and_insert_gradient_of():
net = FvAndInsertGradientNet()
input_data = Tensor(np.array([1.0], np.float32))
# if use grad_all_list, the generated graph will have env_setitem
# if use grad_all_list, the generated graph will have EnvironSet
# as gradient for inputs is constant zero, so it will depend on result of grad.
grad_net = grad_by_list(net, ParameterTuple(net.trainable_params()))
print(grad_net(input_data))