From 3872e29c3b33edd04952a2293ad5706226b925fe Mon Sep 17 00:00:00 2001 From: chenfei Date: Wed, 16 Mar 2022 16:14:16 +0800 Subject: [PATCH] add ut for tuple trans --- mindspore/ccsrc/debug/anf_ir_dump.cc | 108 +++++++++------ .../irpass/call_graph_tuple_transform.h | 18 +-- tests/st/control/test_call_tuple_transform.py | 99 +++++++++++++ tests/ut/cpp/optimizer/opt_test.cc | 131 +++++++++++++++++- .../gtest_input/optimizer/opt_test.py | 93 ++++++++++++- 5 files changed, 390 insertions(+), 59 deletions(-) create mode 100644 tests/st/control/test_call_tuple_transform.py diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index 49a10999a9a..112b4ac8d1a 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -683,6 +683,7 @@ void DumpSubgraph(const OrderedMap } void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) { + MS_LOG(INFO) << "Set dump config:" << str; static mindspore::HashMap dump_level_map = { {kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}}; auto it = dump_level_map.find(str); @@ -700,11 +701,64 @@ void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) { } } +std::shared_ptr> GetAllConfigStrings(const std::string &config_full_string) { + size_t start_pos = 0; + auto config_strings = std::make_shared>(); + // if '#' is the last char of str, the str is legal, so we use '<=' but not '<'. + while (start_pos <= config_full_string.size()) { + auto pos = config_full_string.find('#', start_pos); + if (pos == std::string::npos) { + pos = config_full_string.size(); + } + auto substr = config_full_string.substr(start_pos, pos - start_pos); + // Skip the '#' + start_pos = pos + 1; + if (substr.empty()) { + continue; + } + (void)config_strings->insert(substr); + } + return config_strings; +} + +bool ConfigsAreLegal(const std::shared_ptr> &config_strings) { + // Value 'int' is used to mark config group id + HashMap config_white_list = {{kDumpConfigLineLevel0, 0}, + {kDumpConfigLineLevel1, 0}, + {kDumpConfigLineLevel2, 0}, + {kDumpConfigDisableBackend, 1}, + {kDumpConfigEnablePassIR, 2}}; + // Key 'int' is config group id, value is the config. + HashMap config_groups; + for (const auto &config_string : *config_strings) { + auto config_white_list_it = config_white_list.find(config_string); + if (config_white_list_it == config_white_list.end()) { + std::ostringstream buffer; + buffer << "Support configs:\n" + << "[0]: " << kDumpConfigLineLevel0 << "\n" + << "[1]: " << kDumpConfigLineLevel1 << "\n" + << "[2]: " << kDumpConfigLineLevel2 << "\n" + << "[3]: " << kDumpConfigDisableBackend << "\n" + << "[4]: " << kDumpConfigEnablePassIR; + MS_LOG(WARNING) << "Illegal dump config:\n" << config_string << "\n" << buffer.str(); + return false; + } + auto group_id = config_white_list_it->second; + // Check conflict configs. + auto config_groups_it = config_groups.find(group_id); + if (config_groups_it != config_groups.end()) { + const auto &record_config = config_groups_it->second; + MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: [" << record_config << "] and [" << config_string + << "].\n" + << "Please keep only one of them."; + return false; + } + config_groups[group_id] = config_string; + } + return true; +} + DumpConfig GetDumpConfig() { - static std::vector> config_white_list = { - {kDumpConfigLineLevel0, kDumpConfigLineLevel1, kDumpConfigLineLevel2}, - {kDumpConfigDisableBackend}, - {kDumpConfigEnablePassIR}}; static DumpConfig dump_config = DumpConfig(); static bool parsed = false; if (parsed) { @@ -713,9 +767,6 @@ DumpConfig GetDumpConfig() { parsed = true; // Start parse config. std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG")); - std::vector>> configs = {std::make_shared>(), - std::make_shared>(), - std::make_shared>()}; auto constexpr max_string_len = 100; if (str.size() > max_string_len) { MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len; @@ -724,45 +775,12 @@ DumpConfig GetDumpConfig() { if (str.empty()) { return dump_config; } - size_t start_pos = 0; - // if '#' is the last char of str, the str is illegal, so we use '<=' but not '<'. - while (start_pos <= str.size()) { - auto pos = str.find('#', start_pos); - if (pos == std::string::npos) { - pos = str.size(); - } - auto substr = str.substr(start_pos, pos - start_pos); - start_pos = pos + 1; - bool is_illegal_config = true; - for (size_t i = 0; i < config_white_list.size(); i++) { - if (config_white_list[i].find(substr) != config_white_list[i].end()) { - is_illegal_config = false; - (void)configs[i]->insert(substr); - if (configs[i]->size() > 1) { - std::ostringstream buffer; - (void)std::for_each(configs[i]->begin(), configs[i]->end(), [&buffer](const std::string &config) { - buffer << "\n" << config; - }); - MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: " << buffer.str() << "\n" - << "Please keep only one of them."; - return dump_config; - } - } - } - if (is_illegal_config) { - std::ostringstream buffer; - buffer << "Support configs:\n" - << "[0]: " << kDumpConfigLineLevel0 << "\n" - << "[1]: " << kDumpConfigLineLevel1 << "\n" - << "[2]: " << kDumpConfigLineLevel2 << "\n" - << "[3]: " << kDumpConfigDisableBackend << "\n" - << "[4]: " << kDumpConfigEnablePassIR; - MS_LOG(WARNING) << "Illegal dump config:\n" << substr << "\n" << buffer.str(); - return {}; - } + auto config_strings = GetAllConfigStrings(str); + if (!ConfigsAreLegal(config_strings)) { + return dump_config; } - for (auto &config : configs) { - SetDumpConfigByString(*config->begin(), &dump_config); + for (const auto &config : *config_strings) { + SetDumpConfigByString(config, &dump_config); } return dump_config; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h b/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h index 691c02348a4..f01e5ec80d3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/call_graph_tuple_transform.h @@ -87,7 +87,7 @@ class GraphTupleTransform : public AnfVisitor { GraphTupleParamTransform graph_transform_; }; -// {,kPrimPartial, G, Tuple_Xs} +// {PrimPartial, G, Tuple_Xs} // => // {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}} // transform partial's tuple binding args to flat inputs. @@ -102,12 +102,12 @@ class PartialTupleArgTransform : public AnfVisitor { auto partial = node->cast(); const auto &partial_inputs = partial->inputs(); const auto &fg = partial->func_graph(); - // And primitive and function value node into args. constexpr auto kPartialFirstArgIndex = 2; - auto new_args = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex); - auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_args); + // Put ValueNode and ValueNode into new_inputs. + auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex); + auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs); if (change) { - auto new_partial = fg->NewCNode(new_args); + auto new_partial = fg->NewCNode(new_inputs); new_partial->set_abstract(partial->abstract()); return new_partial; } @@ -132,11 +132,11 @@ class CallTupleArgTransform : public AnfVisitor { const auto &call_inputs = call_node->inputs(); const auto &fg = call_node->func_graph(); MS_EXCEPTION_IF_NULL(fg); - // Add function value node into args. - auto new_args = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1); - auto change = FlattenArgs(fg, call_inputs, 1, &new_args); + // Put ValueNode into inputs. + auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1); + auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs); if (change) { - auto new_call = fg->NewCNode(new_args); + auto new_call = fg->NewCNode(new_inputs); new_call->set_abstract(call_node->abstract()); return new_call; } diff --git a/tests/st/control/test_call_tuple_transform.py b/tests/st/control/test_call_tuple_transform.py new file mode 100644 index 00000000000..d02de12ce4f --- /dev/null +++ b/tests/st/control/test_call_tuple_transform.py @@ -0,0 +1,99 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore as ms +from mindspore import context +from mindspore.ops import operations as P +from mindspore.common.api import ms_function +from mindspore.common.tensor import Tensor +import mindspore.nn as nn + +import numpy as np +import pytest + + +class MAPPOCriticNet(nn.Cell): + def __init__(self): + super().__init__() + self.linear1_actor = nn.Dense(54, # input local obs shape + 64, + weight_init='XavierUniform', + # paper uses orthogonal with gain 5/3 for every dense123 + has_bias=False, + activation=nn.Tanh()) + + def construct(self, x): + # Feature Extraction + x = self.linear1_actor(x) + + return x + + +class MAPPOActor(nn.Cell): + + def __init__(self, actor_net): + super().__init__() + self.actor_net = actor_net + + def construct(self, inputs_data): + _, global_obs = inputs_data + out = self.actor_net(global_obs) + + return out + + +class TestClass(nn.Cell): + def __init__(self, actor_list): + super().__init__() + self.zero = Tensor(0, ms.int32) + self.actor_list = actor_list + self.less = P.Less() + self.zeros = P.Zeros() + + def train(self): + state = Tensor(np.random.random((3, 128, 18)), ms.float32) + init_global_obs = self.zeros((128, 54), ms.float32) + out = self.test(state, init_global_obs) + return out + + @ms_function + def test(self, state, init_global_obs): + num_agent = self.zero + while self.less(num_agent, 3): + samples = (state[num_agent], init_global_obs) + self.actor_list[num_agent](samples) + num_agent += 1 + + return num_agent + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_net(): + """ + Feature: Tuple arg transform. + Description: Test the pass: transform tuple arg to tensor arg. + Expectation: Compile done without error. + """ + context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path="./graph_ir") + actor_list = nn.CellList() + for _ in range(3): + net = MAPPOCriticNet() + actor = MAPPOActor(net) + actor_list.append(actor) + test = TestClass(actor_list) + graph_out = test.train() + + assert np.allclose(graph_out.asnumpy(), graph_out.asnumpy(), 0.0001, 0.0001) diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 7800c70c99f..ce592bdb55d 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,10 +22,12 @@ #include "ir/anf.h" #include "ir/visitor.h" #include "ir/func_graph_cloner.h" +#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/opt.h" #include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/arithmetic_simplify.h" +#include "pipeline/jit/action.h" #include "debug/draw.h" #include "frontend/operator/ops.h" @@ -107,6 +109,8 @@ class TestOptOpt : public UT::Common { FuncGraphPairMapEquiv equiv_graph; NodeMapEquiv equiv_node; + irpass::OptimizeIRPassLib irpass_lib; + static const PrimitivePtr P; static const PrimitivePtr Q; static const PrimitivePtr R; @@ -115,6 +119,7 @@ class TestOptOpt : public UT::Common { SubstitutionPtr elim_R; SubstitutionPtr idempotent_P; SubstitutionPtr Qct_to_P; + SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_; }; const PrimitivePtr TestOptOpt::P = std::make_shared("P"); @@ -148,8 +153,8 @@ TEST_F(TestOptOpt, ElimTwo) { } TEST_F(TestOptOpt, ElimR) { - FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1"); - FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after"); + FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1"); + FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after"); ASSERT_TRUE(nullptr != before); ASSERT_TRUE(nullptr != after); @@ -208,5 +213,125 @@ TEST_F(TestOptOpt, CSE) { ASSERT_EQ(manager2->all_nodes().size(), 12); } +size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) { + // Check tuple params and tuple args. + auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude); + size_t tuple_arg_param_num = 0; + auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t { + auto abs = node->abstract(); + MS_EXCEPTION_IF_NULL(abs); + return abs->isa() ? prev_num + 1 : prev_num; + }; + for (const auto &node : all_nodes) { + // Count func graph call tuple args. + if (node->isa() && !IsValueNode(node->cast()->input(0))) { + auto call_node = node->cast(); + tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(), + tuple_arg_param_num, tuple_accumulate_func); + } + // Count partial tuple args. + if (IsPrimitiveCNode(node, prim::kPrimPartial)) { + auto partial = node->cast(); + constexpr auto kPartialFirstArgIdx = 2; + tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(), + tuple_arg_param_num, tuple_accumulate_func); + } + + // Count tuple params. + if (IsValueNode(node)) { + auto fg = GetValueNode(node); + tuple_arg_param_num = + std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func); + } + } + return tuple_arg_param_num; +} + +// Feature: Switch call tuple arg transform. +// Description: Test switch call's tuple arg transform.This case include partial's tuple arg and the call's tuple arg in +// the same time. +// Expectation: All tuple args are correctly transformed to tensor args. +TEST_F(TestOptOpt, SwitchPartialTupleTrans) { + FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg"); + ASSERT_TRUE(nullptr != test_graph); + + FuncGraphManagerPtr manager1 = Manage(test_graph); + pipeline::ResourcePtr res = std::make_shared(); + std::vector args_spec; + + // Renormalize firstly. + auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec); + ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0); + + // Flatten tuple param and args. + OptimizerPtr optimizer = std::make_shared("ut_test", res); + SubstitutionList transform(std::vector({tuple_flatten})); + transform(renormalized_fg, optimizer); + + // Renormalize again. + auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec); + ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0); + + abstract::AnalysisResultCacheMgr::GetInstance().Clear(); + abstract::AnalysisContext::ClearContext(); +} + +// Feature: Switch layer call tuple arg transform. +// Description: Test switch layer call's tuple arg transform.This case include partial's tuple arg and the partial's +// tensor arg in the same time. +// Expectation: All tuple args are correctly transformed to tensor args. +TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) { + FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg"); + ASSERT_TRUE(nullptr != test_graph); + + FuncGraphManagerPtr manager1 = Manage(test_graph); + pipeline::ResourcePtr res = std::make_shared(); + std::vector args_spec; + + // Renormalize firstly. + auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec); + ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0); + + // Flatten tuple param and args. + OptimizerPtr optimizer = std::make_shared("ut_test", res); + SubstitutionList transform(std::vector({tuple_flatten})); + transform(renormalized_fg, optimizer); + + // Renormalize again. + auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec); + ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0); + + abstract::AnalysisResultCacheMgr::GetInstance().Clear(); + abstract::AnalysisContext::ClearContext(); +} + +// Feature: Single graph call tuple arg transform. +// Description: Test single graph call's tuple arg transform.This case include tuple in tuple args. +// Expectation: All tuple args are correctly transformed to tensor args. +TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) { + FuncGraphPtr test_graph = + getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg"); + ASSERT_TRUE(nullptr != test_graph); + + FuncGraphManagerPtr manager1 = Manage(test_graph); + pipeline::ResourcePtr res = std::make_shared(); + std::vector args_spec; + + // Renormalize firstly. + auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec); + ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0); + + // Flatten tuple param and args. + OptimizerPtr optimizer = std::make_shared("ut_test", res); + SubstitutionList transform(std::vector({tuple_flatten})); + transform(renormalized_fg, optimizer); + + // Renormalize again. + auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec); + ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0); + + abstract::AnalysisResultCacheMgr::GetInstance().Clear(); + abstract::AnalysisContext::ClearContext(); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index 7bdea479989..a2f8c160c48 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -16,9 +16,11 @@ import numpy as np from mindspore import Tensor +from mindspore import dtype as mstype from mindspore.ops import Primitive from mindspore.ops import _constants as Constants from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.ops.operations import _grad_ops as G # pylint: disable=unused-variable @@ -68,8 +70,12 @@ def test_add_zero(tag): return fns[tag] -def test_elimR(tag): - """ test_elimR """ +def test_elim_r(tag): + """ + Feature: optimizer. + Description: test elimi R. + Expectation: run case with no exception. + """ R = Primitive('R') fns = FnDict() @@ -495,6 +501,7 @@ def test_elim_transpose(tag): return fns[tag] + def test_elim_depend_value(tag): """ test_elim_depend_value """ fns = FnDict() @@ -1203,3 +1210,85 @@ def test_sparse_tensor(tag): return z return fns[tag] + + +# Test ut for file: call_graph_tuple_transform.h. +def test_tuple_flatten(tag): + """ + Feature: optimizer. + Description: test cases for pass: graph_tuple_transform. + Expectation: the tuple args and parameters are successfully flattened by the pass. + """ + fns = FnDict() + w = Tensor(np.random.randn(64, 3, 7, 7).astype(np.float32)) + x = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32)) + y = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32)) + + p = Tensor(3, mstype.float32) + + out_channel = 64 + kernel_size = 7 + conv = P.Conv2D(out_channel, + kernel_size, + mode=1, + pad_mode="valid", + pad=0, + stride=1, + dilation=1, + group=1) + pow_ops = P.Pow() + + @fns + def test_flatten_switch_partial_arg(): + def called_graph_with_tuple(tuple_x, tuple_y): + return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + conv(F.tuple_getitem(tuple_y, 0), + F.tuple_getitem(tuple_y, 1)) + + # Add tuple args in partial args. + func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) + func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) + cond = x < y + + switch_node = F.switch(cond, func1, func2) + # Add tuple args in call args. + return switch_node((pow_ops(x, p), pow_ops(w, p))) + + index = Tensor(1, mstype.int32) + + @fns + def test_flatten_switch_layer_partial_arg(): + def called_graph_with_tuple(tuple_x): + return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + + def called_graph_no_tuple(param1, param2): + return conv(param1, param2) + + # Add tuple args in partial + func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) + func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) + # Add tensor args in partial + func3 = F.partial(called_graph_no_tuple, pow_ops(x, p), pow_ops(w, p)) + switch_node = F.switch_layer(pow_ops(index, index), (func1, func2, func3)) + return switch_node() + + @fns + def test_flatten_simple_call_tuple_in_tuple_arg(): + def called_graph_with_tuple(tuple_x, tuple_tuple_y, tensor_z): + result1 = conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + tuple_0 = F.tuple_getitem(tuple_tuple_y, 0) + result2 = conv(F.tuple_getitem(tuple_0, 0), F.tuple_getitem(tuple_0, 1)) + tensor_1 = F.tuple_getitem(tuple_tuple_y, 1) + result3 = conv(tensor_1, tensor_z) + return result1 + result2 + result3 + + # Tuple arg. + tuple_x_arg = (pow_ops(x, p), pow_ops(w, p)) + # TupleTuple arg. + tuple_0_arg = (pow_ops(x, p), pow_ops(w, p)) + tensor_1_arg = pow_ops(x, p) + tuple_tuple_y_arg = (tuple_0_arg, tensor_1_arg) + # TensorArg + tensor_z_arg = pow_ops(w, p) + return called_graph_with_tuple(tuple_x_arg, tuple_tuple_y_arg, tensor_z_arg) + + return fns[tag]