From 3b87ff1bb23b4279bc0b7d9c01ceec7e4edc03f2 Mon Sep 17 00:00:00 2001 From: zengzitao Date: Thu, 29 Sep 2022 10:29:21 +0800 Subject: [PATCH] lu ascend adapt --- .../common/graph_kernel/adapter/expander.cc | 2 + .../graph_kernel/core/graph_kernel_utils.h | 1 + .../graph_kernel/expanders/custom_op_utils.h | 43 +++++ .../ccsrc/common/graph_kernel/expanders/lu.cc | 178 ++++++++++++++++++ tests/st/ops/graph_kernel/custom/test_lu.py | 59 ++++++ 5 files changed, 283 insertions(+) create mode 100644 mindspore/ccsrc/common/graph_kernel/expanders/lu.cc create mode 100644 tests/st/ops/graph_kernel/custom/test_lu.py diff --git a/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc b/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc index 25c278618a6..410b09163a0 100644 --- a/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc +++ b/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc @@ -63,6 +63,7 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) { {prim::kPrimArgMaxWithValue->name(), {ArgWithValueDeco::Creator}}, {prim::kPrimArgMinWithValue->name(), {ArgWithValueDeco::Creator}}, {prim::kPrimSolveTriangular->name(), {ProcessCustomOpDeco::Creator}}, + {prim::kPrimLU->name(), {ProcessCustomOpDeco::Creator}}, }; const auto iter = creators.find(GetCNodePrimitive(node)->name()); if (iter != creators.end()) { @@ -100,6 +101,7 @@ bool CanExpandFallback(const AnfNodePtr &node) { {kAllTarget, OpLevel_0, prim::kPrimAdam}, // some ops including custom op are only used expand fallbak on Ascend. {kAscendDevice, OpLevel_0, prim::kPrimSolveTriangular}, + {kAscendDevice, OpLevel_0, prim::kPrimLU}, // disabled {kAllTarget, OpLevel_1, prim::kPrimAddN}, {kAllTarget, OpLevel_1, prim::kPrimErfc}, diff --git a/mindspore/ccsrc/common/graph_kernel/core/graph_kernel_utils.h b/mindspore/ccsrc/common/graph_kernel/core/graph_kernel_utils.h index 6cda63f2596..f168dadb9c8 100644 --- a/mindspore/ccsrc/common/graph_kernel/core/graph_kernel_utils.h +++ b/mindspore/ccsrc/common/graph_kernel/core/graph_kernel_utils.h @@ -32,6 +32,7 @@ GVAR_DEF(PrimitivePtr, kPrimElemAny, std::make_shared("ElemAny")); GVAR_DEF(PrimitivePtr, kPrimLayoutTransform, std::make_shared("LayoutTransform")); GVAR_DEF(PrimitivePtr, kPrimStridedSliceOnnx, std::make_shared("StridedSliceOnnx")); GVAR_DEF(PrimitivePtr, kPrimSolveTriangular, std::make_shared("SolveTriangular")); +GVAR_DEF(PrimitivePtr, kPrimLU, std::make_shared("LU")); } // namespace mindspore::prim namespace mindspore::graphkernel { diff --git a/mindspore/ccsrc/common/graph_kernel/expanders/custom_op_utils.h b/mindspore/ccsrc/common/graph_kernel/expanders/custom_op_utils.h index 04bcd4aa3b3..fae91d495c3 100644 --- a/mindspore/ccsrc/common/graph_kernel/expanders/custom_op_utils.h +++ b/mindspore/ccsrc/common/graph_kernel/expanders/custom_op_utils.h @@ -24,6 +24,7 @@ constexpr int64_t kBlock = 16; const char kFuncType[] = "hybrid"; const char kTrsmName[] = "trsm"; +const char kLUName[] = "lu_decomp"; const std::map kTrsmFuncStrMap = { {"trsmL_N_D", @@ -54,13 +55,55 @@ const std::map kTrsmFuncStrMap = { " for k in vectorize(16):\n" " inverse_0[i, l * 16 + k] = a[i, j] * b[j, l * 16 + k]\n" " b[i, l * 16 + k] = b[i, l * 16 + k] - inverse_0[i, l * 16 + k]\n" + " return b\n"}, + {"trsmU_T", + "def trsmU_T(a, b):\n" + " row = b.shape[0]\n" + " col = b.shape[1]\n" + " inverse_0 = allocate((col, ), b.dtype)\n" + " tmp = allocate((col, ), b.dtype)\n" + " for i in range(row):\n" + " for j in range(col):\n" + " tmp[j] = a[j, j]\n" + " b[i, j] = b[i, j] / tmp[j]\n" + " for k in 
vectorize(col):\n"
+     " inverse_0[k] = b[i, j] * a[j, k]\n"
+     " for k in vectorize(j + 1):\n"
+     " inverse_0[k] = (0.0)\n"
+     " for k in vectorize(col):\n"
+     " b[i, k] = b[i, k] - inverse_0[k]\n"
+     " return b\n"}};
+
+const std::map<std::string, std::string> kLUFuncStrMap = {{"lu_decomp",
+  "def lu_decomp(a):\n"
+  "    out_0 = allocate(a.shape, a.dtype)\n"
+  "    out_1 = allocate(a.shape, a.dtype)\n"
+  "    for i in range(a.shape[0]):\n"
+  "        for j in range(a.shape[1]):\n"
+  "            if j > i:\n"
+  "                a[j, i] = a[j, i] / a[i, i]\n"
+  "        for k in range(a.shape[0]):\n"
+  "            for l in vectorize(a.shape[1]):\n"
+  "                out_0[k, l] = a[k, i]\n"
+  "                out_1[k, l] = out_0[k, l] * a[i, l]\n"
+  "                if k > i and l > i:\n"
+  "                    a[k, l] = a[k, l] - out_1[k, l]\n"
+  "    return a\n"}};
+
 const char kTrsmLAttrs[] =
   "{\"pragma_enable_reschedule\": false,"
   " \"enable_hoist_cond_write\": false,"
   " \"enable_approximate_read\": true,"
   " \"enable_post_poly_loop_partition\": false,"
   " \"enable_polytops\": \"always\"}";
+
+const char kLUAttrs[] =
+  "{\"pragma_enable_reschedule\": false,"
+  " \"enable_hoist_cond_write\": false,"
+  " \"enable_double_buffer\": false,"
+  " \"enable_pre_poly_loop_partition\": false,"
+  " \"enable_post_poly_loop_partition\": false,"
+  " \"enable_to_three_address\": false,"
+  " \"enable_polytops\": \"always\"}";
 }  // namespace mindspore::graphkernel::expanders
 #endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_CUS_UTILS_H_
diff --git a/mindspore/ccsrc/common/graph_kernel/expanders/lu.cc b/mindspore/ccsrc/common/graph_kernel/expanders/lu.cc
new file mode 100644
index 00000000000..ef510dbb026
--- /dev/null
+++ b/mindspore/ccsrc/common/graph_kernel/expanders/lu.cc
@@ -0,0 +1,178 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "common/graph_kernel/expanders/op_desc_registry.h"
+#include "common/graph_kernel/expanders/utils.h"
+#include "common/graph_kernel/expanders/custom_op_utils.h"
+
+namespace mindspore::graphkernel::expanders {
+class LU : public OpDesc {
+ public:
+  LU() {}
+  ~LU() = default;
+
+ protected:
+  std::vector<int64_t> GetIndicesValue(int64_t u_left, int64_t u_right, int64_t v_left, int64_t v_right) {
+    std::vector<int64_t> indices_value;
+    for (int64_t u = u_left; u < u_right; ++u) {
+      for (int64_t v = v_left; v < v_right; ++v) {
+        (void)indices_value.emplace_back(u);
+        (void)indices_value.emplace_back(v);
+      }
+    }
+    return indices_value;
+  }
+  std::vector<int32_t> GetEyesValue(int64_t num) {
+    std::vector<int32_t> eyes_value;
+    for (int64_t i = 0; i < num; ++i) {
+      for (int64_t j = 0; j < num; ++j) {
+        if (i == j) {
+          (void)eyes_value.emplace_back(1);
+        } else {
+          (void)eyes_value.emplace_back(0);
+        }
+      }
+    }
+    return eyes_value;
+  }
+  std::vector<int32_t> GetRangeValue(int64_t num) {
+    std::vector<int32_t> range_value;
+    for (int i = 0; i < static_cast<int>(num); ++i) {
+      (void)range_value.emplace_back(i);
+    }
+    return range_value;
+  }
+  NodePtrList Expand(const NodePtrList &inputs) override {
+    auto input_x = inputs[0];
+    auto num = input_x->shape[0];
+    auto loop_count = static_cast<int64_t>(num / kBlock);
+    std::vector<int64_t> strides{1, 1};
+
+    // Blocked LU: factor the diagonal block, solve the row/column panels, then update the trailing matrix.
+    for (int64_t i = 0; i < loop_count; ++i) {
+      std::vector<int64_t> begin_1{i * kBlock, i * kBlock};
+      std::vector<int64_t> end_1{(i + 1) * kBlock, (i + 1) * kBlock};
+      auto stride_1 = gb.StridedSlice(input_x, begin_1, end_1, strides);
+      std::string lu_func_type = kFuncType;
+      std::string lu_func_name = kLUName;
+      // look up the lu_decomp DSL source string by its name
+      auto lu_iter = kLUFuncStrMap.find(lu_func_name);
+      std::string lu_func_source_str = lu_iter->second;
+      size_t lu_inplace_assign_output = 0;
+      std::string lu_func_compile_attrs = kLUAttrs;
+
+      auto custom_lu_decomp_result =
+        gb.Custom({stride_1}, {stride_1->shape, stride_1->type, stride_1->format}, lu_func_name, lu_func_type,
+                  lu_func_source_str, lu_inplace_assign_output, lu_func_compile_attrs);
+      ShapeVector ind_shape{kBlock, kBlock, 2};
+      std::vector<int64_t> ind_value = GetIndicesValue(i * kBlock, (i + 1) * kBlock, i * kBlock, (i + 1) * kBlock);
+
+      auto ind_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ind_shape, &ind_value[0], kNumberTypeInt64);
+      auto first_indices = gb.Value(ind_tensor);
+      input_x = gb.Emit("ScatterNdUpdate", {input_x, first_indices, custom_lu_decomp_result},
+                        {{"use_locking", MakeValue(false)}});
+      if (i < loop_count - 1) {
+        std::vector<int64_t> begin_2{i * kBlock, (i + 1) * kBlock};
+        std::vector<int64_t> end_2{(i + 1) * kBlock, num};
+        auto stride_2 = gb.StridedSlice(input_x, begin_1, end_1, strides);
+        auto stride_3 = gb.StridedSlice(input_x, begin_2, end_2, strides);
+        std::string trsmL_off_diag_func_type = kFuncType;
+        std::string trsmL_off_diag_func_name = kTrsmName;
+        // select the trsm variant (lower triangular, no transpose, unit diagonal) by its suffix
+        trsmL_off_diag_func_name = trsmL_off_diag_func_name + "L_N_U";
+        auto trsmL_off_diag_iter = kTrsmFuncStrMap.find(trsmL_off_diag_func_name);
+        std::string trsmL_off_diag_source_str = trsmL_off_diag_iter->second;
+        size_t trsmL_off_diag_inplace_assign_output = 1;
+        std::string trsmL_off_diag_compile_attrs = kTrsmLAttrs;
+        auto custom_trsmL_off_diag_result =
+          gb.Custom({stride_2, stride_3}, {stride_3->shape, stride_3->type, stride_3->format},
+                    trsmL_off_diag_func_name, trsmL_off_diag_func_type, trsmL_off_diag_source_str,
+                    trsmL_off_diag_inplace_assign_output, trsmL_off_diag_compile_attrs);
+        ShapeVector sec_indices_shape{kBlock, num - (i + 1) * kBlock, 2};
+        std::vector<int64_t> sec_indices_value = GetIndicesValue(i * kBlock, (i + 1) * kBlock, (i + 1) * kBlock, num);
+
+        auto sec_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, sec_indices_shape,
+                                                                   &sec_indices_value[0], kNumberTypeInt64);
+        auto sec_indices = gb.Value(sec_indices_tensor);
+        input_x = gb.Emit("ScatterNdUpdate", {input_x, sec_indices, custom_trsmL_off_diag_result},
+                          {{"use_locking", MakeValue(false)}});
+
+        std::vector<int64_t> begin_3{(i + 1) * kBlock, i * kBlock};
+        std::vector<int64_t> end_3{num, (i + 1) * kBlock};
+        auto stride_4 = gb.StridedSlice(input_x, begin_1, end_1, strides);
+        auto stride_5 = gb.StridedSlice(input_x, begin_3, end_3, strides);
+        std::string trsmUT_func_type = kFuncType;
+        std::string trsmUT_func_name = kTrsmName;
+        // select the trsm variant (upper triangular, transposed) by its suffix
+        trsmUT_func_name = trsmUT_func_name + "U_T";
+        auto trsmUT_iter = kTrsmFuncStrMap.find(trsmUT_func_name);
+        std::string trsmUT_source_str = trsmUT_iter->second;
+        size_t trsmUT_inplace_assign_output = 1;
+        std::string trsmUT_compile_attrs = kTrsmLAttrs;
+        auto custom_trsmUT_result =
+          gb.Custom({stride_4, stride_5}, {stride_5->shape, stride_5->type, stride_5->format}, trsmUT_func_name,
+                    trsmUT_func_type, trsmUT_source_str, trsmUT_inplace_assign_output, trsmUT_compile_attrs);
+        ShapeVector third_indices_shape{num - (i + 1) * kBlock, kBlock, 2};
+        std::vector<int64_t> thi_indices_v = GetIndicesValue((i + 1) * kBlock, num, i * kBlock, (i + 1) * kBlock);
+
+        auto thi_indices_tensor =
+          std::make_shared<tensor::Tensor>(kNumberTypeInt64, third_indices_shape, &thi_indices_v[0], kNumberTypeInt64);
+        auto thi_indices = gb.Value(thi_indices_tensor);
+        input_x =
+          gb.Emit("ScatterNdUpdate", {input_x, thi_indices, custom_trsmUT_result}, {{"use_locking", MakeValue(false)}});
+
+        std::vector<int64_t> begin_4{(i + 1) * kBlock, (i + 1) * kBlock};
+        std::vector<int64_t> end_4{num, num};
+        std::vector<int64_t> begin_5{i * kBlock, (i + 1) * kBlock};
+        std::vector<int64_t> end_5{(i + 1) * kBlock, num};
+        auto stride_6 = gb.StridedSlice(input_x, begin_4, end_4, strides);
+        auto stride_7 = gb.StridedSlice(input_x, begin_3, end_3, strides);
+        auto stride_8 = gb.StridedSlice(input_x, begin_5, end_5, strides);
+        // on Ascend, MatMul's inputs must be fp16
+        stride_7 = gb.Cast(stride_7, kNumberTypeFloat16);
+        stride_8 = gb.Cast(stride_8, kNumberTypeFloat16);
+        auto matmul_stride_7_8 = gb.MatMul(stride_7, stride_8);
+        matmul_stride_7_8 = gb.Cast(matmul_stride_7_8, kNumberTypeFloat32);
+        auto final_update = gb.Sub(stride_6, matmul_stride_7_8);
+        ShapeVector final_indices_shape{num - (i + 1) * kBlock, num - (i + 1) * kBlock, 2};
+        std::vector<int64_t> final_indices_value = GetIndicesValue((i + 1) * kBlock, num, (i + 1) * kBlock, num);
+
+        auto final_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, final_indices_shape,
+                                                                     &final_indices_value[0], kNumberTypeInt64);
+        auto f_indices = gb.Value(final_indices_tensor);
+        input_x = gb.Emit("ScatterNdUpdate", {input_x, f_indices, final_update}, {{"use_locking", MakeValue(false)}});
+      }
+    }
+    ShapeVector eyes_shape{num, num};
+    std::vector<int32_t> eyes_value = GetEyesValue(num);
+    auto eyes_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, eyes_shape, &eyes_value[0], kNumberTypeInt32);
+    auto eyes = gb.Value(eyes_tensor);
+    auto eyes_cnode = gb.Reshape(eyes, eyes_shape);
+
+    ShapeVector r_shape{num};
+    std::vector<int32_t> r_value = GetRangeValue(num);
+    auto range_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, r_shape, &r_value[0], kNumberTypeInt32);
+    auto range = gb.Value(range_tensor);
+    auto range_cnode = gb.Reshape(range, r_shape);
+    return {input_x, range_cnode, eyes_cnode};
+  }
+};
+EXPANDER_OP_DESC_REGISTER("LU", LU);
+}  // namespace mindspore::graphkernel::expanders
diff --git a/tests/st/ops/graph_kernel/custom/test_lu.py b/tests/st/ops/graph_kernel/custom/test_lu.py
new file mode 100644
index 00000000000..d5062ce7e63
--- /dev/null
+++ b/tests/st/ops/graph_kernel/custom/test_lu.py
@@ -0,0 +1,59 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Tensor
+from mindspore.scipy.ops import LU
+
+
+class NetLU(nn.Cell):
+    def __init__(self):
+        super(NetLU, self).__init__()
+        self.ops = LU()
+
+    def construct(self, a):
+        return self.ops(a)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_lu():
+    """
+    Feature: LU op expanded through custom ops on Ascend.
+    Description: test the LU op result in graph mode.
+    Expectation: the result matches the expected packed LU factors.
+    """
+    num = 32
+    base_matrix = np.zeros((num, num))
+    for i in range(num):
+        for j in range(num):
+            base_matrix[i, j] = min(min(i, j), 8) + 1
+    upper_matrix = np.triu(base_matrix).astype(np.float32)
+    lower_matrix = np.tril(np.ones((num, num))).astype(np.float32)
+    input1 = np.dot(lower_matrix, upper_matrix)
+    expect = upper_matrix + lower_matrix - np.eye(num)
+    real_input = Tensor(input1, mstype.float32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    net = NetLU()
+    result = net(real_input)
+    rtol = 0.001
+    atol = 0.001
+    assert np.allclose(result[0].asnumpy(), expect, rtol, atol, equal_nan=True)
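---
Reviewer note (not part of the patch): the expander above implements a blocked, right-looking LU factorization without pivoting. Each 16x16 diagonal block is factorized in place by the lu_decomp hybrid op, the row panel is solved with the trsmL_N_U kernel, the column panel with the trsmU_T kernel, and the trailing sub-matrix is updated with Cast/MatMul/Sub before the next block. Below is a minimal NumPy sketch of that scheme, assuming a square input whose size is a multiple of kBlock (16) and that no pivoting is needed; the name blocked_lu is illustrative only.

import numpy as np


def blocked_lu(a, block=16):
    """Blocked LU without pivoting; returns unit-diagonal L and U packed in one matrix."""
    a = np.array(a, dtype=np.float64)
    n = a.shape[0]
    assert n % block == 0, "the expander assumes the size is a multiple of kBlock"
    for s in range(0, n, block):
        e = s + block
        # Unblocked LU of the diagonal block (role of the lu_decomp custom op).
        for k in range(s, e):
            a[k + 1:e, k] /= a[k, k]
            a[k + 1:e, k + 1:e] -= np.outer(a[k + 1:e, k], a[k, k + 1:e])
        if e < n:
            l_ss = np.tril(a[s:e, s:e], -1) + np.eye(block)
            u_ss = np.triu(a[s:e, s:e])
            # Row panel: solve L_ss @ X = A[s:e, e:]   (trsmL_N_U step).
            a[s:e, e:] = np.linalg.solve(l_ss, a[s:e, e:])
            # Column panel: solve X @ U_ss = A[e:, s:e] (trsmU_T step).
            a[e:, s:e] = np.linalg.solve(u_ss.T, a[e:, s:e].T).T
            # Trailing update (Cast + MatMul + Sub in the expander).
            a[e:, e:] -= a[e:, s:e] @ a[s:e, e:]
    return a

With the test's input (lower_matrix @ upper_matrix, where lower_matrix has a unit diagonal), blocked_lu(input1) should reproduce the expected packed factors upper_matrix + lower_matrix - np.eye(num), which is what result[0] is compared against.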