forked from mindspore-Ecosystem/mindspore
!43188 GraphKernel LU op use custom
Merge pull request !43188 from ZengZitao/lu_adapt
commit 04ede821d9
@@ -63,6 +63,7 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
    {prim::kPrimArgMaxWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimArgMinWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimSolveTriangular->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimLU->name(), {ProcessCustomOpDeco::Creator}},
  };
  const auto iter = creators.find(GetCNodePrimitive(node)->name());
  if (iter != creators.end()) {
@@ -100,6 +101,7 @@ bool CanExpandFallback(const AnfNodePtr &node) {
    {kAllTarget, OpLevel_0, prim::kPrimAdam},
    // some ops, including custom ops, use expand fallback only on Ascend.
    {kAscendDevice, OpLevel_0, prim::kPrimSolveTriangular},
    {kAscendDevice, OpLevel_0, prim::kPrimLU},
    // disabled
    {kAllTarget, OpLevel_1, prim::kPrimAddN},
    {kAllTarget, OpLevel_1, prim::kPrimErfc},
@@ -32,6 +32,7 @@ GVAR_DEF(PrimitivePtr, kPrimElemAny, std::make_shared<Primitive>("ElemAny"));
GVAR_DEF(PrimitivePtr, kPrimLayoutTransform, std::make_shared<Primitive>("LayoutTransform"));
GVAR_DEF(PrimitivePtr, kPrimStridedSliceOnnx, std::make_shared<Primitive>("StridedSliceOnnx"));
GVAR_DEF(PrimitivePtr, kPrimSolveTriangular, std::make_shared<Primitive>("SolveTriangular"));
GVAR_DEF(PrimitivePtr, kPrimLU, std::make_shared<Primitive>("LU"));
} // namespace mindspore::prim

namespace mindspore::graphkernel {
@@ -24,6 +24,7 @@ constexpr int64_t kBlock = 16;

const char kFuncType[] = "hybrid";
const char kTrsmName[] = "trsm";
const char kLUName[] = "lu_decomp";

const std::map<std::string, std::string> kTrsmFuncStrMap = {
  {"trsmL_N_D",
@@ -54,13 +55,55 @@ const std::map<std::string, std::string> kTrsmFuncStrMap = {
   "        for k in vectorize(16):\n"
   "          inverse_0[i, l * 16 + k] = a[i, j] * b[j, l * 16 + k]\n"
   "          b[i, l * 16 + k] = b[i, l * 16 + k] - inverse_0[i, l * 16 + k]\n"
   "  return b\n"},
  {"trsmU_T",
   "def trsmU_T(a, b):\n"
   "  row = b.shape[0]\n"
   "  col = b.shape[1]\n"
   "  inverse_0 = allocate((col, ), b.dtype)\n"
   "  tmp = allocate((col, ), b.dtype)\n"
   "  for i in range(row):\n"
   "    for j in range(col):\n"
   "      tmp[j] = a[j, j]\n"
   "      b[i, j] = b[i, j] / tmp[j]\n"
   "      for k in vectorize(col):\n"
   "        inverse_0[k] = b[i, j] * a[j, k]\n"
   "      for k in vectorize(j + 1):\n"
   "        inverse_0[k] = (0.0)\n"
   "      for k in vectorize(col):\n"
   "        b[i, k] = b[i, k] - inverse_0[k]\n"
   "  return b\n"}};
const std::map<std::string, std::string> kLUFuncStrMap = {{"lu_decomp",
  "def lu_decomp(a):\n"
  "  out_0 = allocate(a.shape, a.dtype)\n"
  "  out_1 = allocate(a.shape, a.dtype)\n"
  "  for i in range(a.shape[0]):\n"
  "    for j in range(a.shape[1]):\n"
  "      if j > i:\n"
  "        a[j, i] = a[j, i] / a[i, i]\n"
  "    for k in range(a.shape[0]):\n"
  "      for l in vectorize(a.shape[1]):\n"
  "        out_0[k, l] = a[k, i]\n"
  "        out_1[k, l] = out_0[k, l] * a[i, l]\n"
  "        if k > i and l > i:\n"
  "          a[k, l] = a[k, l] - out_1[k, l]\n"
  "  return a\n"}};

const char kTrsmLAttrs[] =
  "{\"pragma_enable_reschedule\": false,"
  " \"enable_hoist_cond_write\": false,"
  " \"enable_approximate_read\": true,"
  " \"enable_post_poly_loop_partition\": false,"
  " \"enable_polytops\": \"always\"}";

const char kLUAttrs[] =
  "{\"pragma_enable_reschedule\": false,"
  " \"enable_hoist_cond_write\": false,"
  " \"enable_double_buffer\": false,"
  " \"enable_pre_poly_loop_partition\": false,"
  " \"enable_post_poly_loop_partition\": false,"
  " \"enable_to_three_address\": false,"
  " \"enable_polytops\": \"always\"}";
} // namespace mindspore::graphkernel::expanders
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_CUS_UTILS_H_
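For reference, the two hybrid kernels added above amount to an unpivoted, in-place LU of one tile and a row-wise triangular solve against an upper-triangular factor. Below is a minimal NumPy sketch of the same arithmetic; the helper names and the NumPy framing are illustrative only, not part of the diff, and the real kernels are compiled from the strings above by the hybrid custom-op frontend.

    import numpy as np

    def lu_decomp_reference(a):
        """In-place unpivoted LU of one float tile: unit-diagonal L below, U on and above the diagonal."""
        n = a.shape[0]
        for i in range(n):
            a[i + 1:, i] /= a[i, i]                                    # column i of L (pivots assumed nonzero)
            a[i + 1:, i + 1:] -= np.outer(a[i + 1:, i], a[i, i + 1:])  # update the trailing block
        return a

    def trsm_u_t_reference(a, b):
        """Row-wise solve of X @ U = B, where U is the upper triangle (with diagonal) of `a`."""
        u = np.triu(a)
        x = b.astype(float)
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                x[i, j] /= u[j, j]                       # divide by the pivot
                x[i, j + 1:] -= x[i, j] * u[j, j + 1:]   # eliminate the solved entry from the rest of the row
        return x                                         # equals b @ np.linalg.inv(u)

The kTrsmLAttrs / kLUAttrs JSON strings are compile attributes handed to gb.Custom together with the kernel source in the expander below.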
@@ -0,0 +1,178 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include <utility>
#include <vector>
#include <memory>

#include "common/graph_kernel/expanders/op_desc_registry.h"
#include "common/graph_kernel/expanders/utils.h"
#include "common/graph_kernel/expanders/custom_op_utils.h"

namespace mindspore::graphkernel::expanders {
class LU : public OpDesc {
 public:
  LU() {}
  ~LU() = default;

 protected:
  std::vector<int64_t> GetIndicesValue(int64_t u_left, int64_t u_right, int64_t v_left, int64_t v_right) {
    std::vector<int64_t> indices_value;
    for (int64_t u = u_left; u < u_right; ++u) {
      for (int64_t v = v_left; v < v_right; ++v) {
        (void)indices_value.emplace_back(u);
        (void)indices_value.emplace_back(v);
      }
    }
    return indices_value;
  }
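  // For example, GetIndicesValue(0, 2, 0, 2) returns the flat sequence {0, 0, 0, 1, 1, 0, 1, 1};
  // viewed with shape {2, 2, 2}, it is the per-element (row, col) coordinate tensor that
  // ScatterNdUpdate expects when a 2x2 block is written back into the full matrix.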
  std::vector<int32_t> GetEyesValue(int64_t num) {
    std::vector<int32_t> eyes_value;
    for (int64_t i = 0; i < num; ++i) {
      for (int64_t j = 0; j < num; ++j) {
        if (i == j) {
          (void)eyes_value.emplace_back(1);
        } else {
          (void)eyes_value.emplace_back(0);
        }
      }
    }
    return eyes_value;
  }
  std::vector<int32_t> GetRangeValue(int64_t num) {
    std::vector<int32_t> range_value;
    for (int i = 0; i < static_cast<int>(num); ++i) {
      (void)range_value.emplace_back(i);
    }
    return range_value;
  }
  NodePtrList Expand(const NodePtrList &inputs) override {
    auto input_x = inputs[0];
    auto num = input_x->shape[0];
    auto loop_count = static_cast<int64_t>(num / kBlock);
    std::vector<int64_t> strides{1, 1};

    // blocked LU decomposition built from the custom DSL kernels
    for (int64_t i = 0; i < loop_count; ++i) {
      std::vector<int64_t> begin_1{i * kBlock, i * kBlock};
      std::vector<int64_t> end_1{(i + 1) * kBlock, (i + 1) * kBlock};
      auto stride_1 = gb.StridedSlice(input_x, begin_1, end_1, strides);
      std::string lu_func_type = kFuncType;
      std::string lu_func_name = kLUName;
      // look up the lu kernel source by its name
      auto lu_iter = kLUFuncStrMap.find(lu_func_name);
      std::string lu_func_source_str = lu_iter->second;
      size_t lu_inplace_assign_output = 0;
      std::string lu_func_compile_attrs = kLUAttrs;

      auto custom_lu_decomp_result =
        gb.Custom({stride_1}, {stride_1->shape, stride_1->type, stride_1->format}, lu_func_name, lu_func_type,
                  lu_func_source_str, lu_inplace_assign_output, lu_func_compile_attrs);
      ShapeVector ind_shape{kBlock, kBlock, 2};
      std::vector<int64_t> ind_value = GetIndicesValue(i * kBlock, (i + 1) * kBlock, i * kBlock, (i + 1) * kBlock);

      auto ind_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ind_shape, &ind_value[0], kNumberTypeInt64);
      auto first_indices = gb.Value(ind_tensor);
      input_x = gb.Emit("ScatterNdUpdate", {input_x, first_indices, custom_lu_decomp_result},
                        {{"use_locking", MakeValue(false)}});
      if (i < loop_count - 1) {
        std::vector<int64_t> begin_2{i * kBlock, (i + 1) * kBlock};
        std::vector<int64_t> end_2{(i + 1) * kBlock, num};
        auto stride_2 = gb.StridedSlice(input_x, begin_1, end_1, strides);
        auto stride_3 = gb.StridedSlice(input_x, begin_2, end_2, strides);
        std::string trsmL_off_diag_func_type = kFuncType;
        std::string trsmL_off_diag_func_name = kTrsmName;
        // select the trsmL_off_diag kernel variant by appending its attribute suffix
        trsmL_off_diag_func_name = trsmL_off_diag_func_name + "L_N_U";
        auto trsmL_off_diag_iter = kTrsmFuncStrMap.find(trsmL_off_diag_func_name);
        std::string trsmL_off_diag_source_str = trsmL_off_diag_iter->second;
        size_t trsmL_off_diag_inplace_assign_output = 1;
        std::string trsmL_off_diag_compile_attrs = kTrsmLAttrs;
        auto custom_trsmL_off_diag_result =
          gb.Custom({stride_2, stride_3}, {stride_3->shape, stride_3->type, stride_3->format}, trsmL_off_diag_func_name,
                    trsmL_off_diag_func_type, trsmL_off_diag_source_str, trsmL_off_diag_inplace_assign_output,
                    trsmL_off_diag_compile_attrs);
        ShapeVector sec_indices_shape{kBlock, num - (i + 1) * kBlock, 2};
        std::vector<int64_t> sec_indices_value = GetIndicesValue(i * kBlock, (i + 1) * kBlock, (i + 1) * kBlock, num);

        auto sec_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, sec_indices_shape,
                                                                   &sec_indices_value[0], kNumberTypeInt64);
        auto sec_indices = gb.Value(sec_indices_tensor);
        input_x = gb.Emit("ScatterNdUpdate", {input_x, sec_indices, custom_trsmL_off_diag_result},
                          {{"use_locking", MakeValue(false)}});

        std::vector<int64_t> begin_3{(i + 1) * kBlock, i * kBlock};
        std::vector<int64_t> end_3{num, (i + 1) * kBlock};
        auto stride_4 = gb.StridedSlice(input_x, begin_1, end_1, strides);
        auto stride_5 = gb.StridedSlice(input_x, begin_3, end_3, strides);
        std::string trsmUT_func_type = kFuncType;
        std::string trsmUT_func_name = kTrsmName;
        // select the trsmU_T kernel variant by appending its attribute suffix
        trsmUT_func_name = trsmUT_func_name + "U_T";
        auto trsmUT_iter = kTrsmFuncStrMap.find(trsmUT_func_name);
        std::string trsmUT_source_str = trsmUT_iter->second;
        size_t trsmUT_inplace_assign_output = 1;
        std::string trsmUT_compile_attrs = kTrsmLAttrs;
        auto custom_trsmUT_result =
          gb.Custom({stride_4, stride_5}, {stride_5->shape, stride_5->type, stride_5->format}, trsmUT_func_name,
                    trsmUT_func_type, trsmUT_source_str, trsmUT_inplace_assign_output, trsmUT_compile_attrs);
        ShapeVector third_indices_shape{num - (i + 1) * kBlock, kBlock, 2};
        std::vector<int64_t> thi_indices_v = GetIndicesValue((i + 1) * kBlock, num, i * kBlock, (i + 1) * kBlock);

        auto thi_indices_tensor =
          std::make_shared<tensor::Tensor>(kNumberTypeInt64, third_indices_shape, &thi_indices_v[0], kNumberTypeInt64);
        auto thi_indices = gb.Value(thi_indices_tensor);
        input_x =
          gb.Emit("ScatterNdUpdate", {input_x, thi_indices, custom_trsmUT_result}, {{"use_locking", MakeValue(false)}});

        std::vector<int64_t> begin_4{(i + 1) * kBlock, (i + 1) * kBlock};
        std::vector<int64_t> end_4{num, num};
        std::vector<int64_t> begin_5{i * kBlock, (i + 1) * kBlock};
        std::vector<int64_t> end_5{(i + 1) * kBlock, num};
        auto stride_6 = gb.StridedSlice(input_x, begin_4, end_4, strides);
        auto stride_7 = gb.StridedSlice(input_x, begin_3, end_3, strides);
        auto stride_8 = gb.StridedSlice(input_x, begin_5, end_5, strides);
        // on Ascend, MatMul's inputs must be fp16
        stride_7 = gb.Cast(stride_7, kNumberTypeFloat16);
        stride_8 = gb.Cast(stride_8, kNumberTypeFloat16);
        auto matmul_stride_7_8 = gb.MatMul(stride_7, stride_8);
        matmul_stride_7_8 = gb.Cast(matmul_stride_7_8, kNumberTypeFloat32);
        auto final_update = gb.Sub(stride_6, matmul_stride_7_8);
        ShapeVector final_indices_shape{num - (i + 1) * kBlock, num - (i + 1) * kBlock, 2};
        std::vector<int64_t> final_indices_value = GetIndicesValue((i + 1) * kBlock, num, (i + 1) * kBlock, num);

        auto final_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, final_indices_shape,
                                                                     &final_indices_value[0], kNumberTypeInt64);
        auto f_indices = gb.Value(final_indices_tensor);
        input_x = gb.Emit("ScatterNdUpdate", {input_x, f_indices, final_update}, {{"use_locking", MakeValue(false)}});
      }
    }
    ShapeVector eyes_shape{num, num};
    std::vector<int32_t> eyes_value = GetEyesValue(num);
    auto eyes_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, eyes_shape, &eyes_value[0], kNumberTypeInt32);
    auto eyes = gb.Value(eyes_tensor);
    auto eyes_cnode = gb.Reshape(eyes, eyes_shape);

    ShapeVector r_shape{num};
    std::vector<int32_t> r_value = GetRangeValue(num);
    auto range_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, r_shape, &r_value[0], kNumberTypeInt32);
    auto range = gb.Value(range_tensor);
    auto range_cnode = gb.Reshape(range, r_shape);
    return {input_x, range_cnode, eyes_cnode};
  }
};
EXPANDER_OP_DESC_REGISTER("LU", LU);
} // namespace mindspore::graphkernel::expanders
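Taken together, Expand() lowers the LU op into a right-looking blocked factorization over kBlock x kBlock tiles: the lu_decomp kernel factors the diagonal tile, the trsmL_N_U kernel solves the row panel against the tile's lower factor (its source is not shown in this diff; it is assumed here to be a unit-diagonal forward solve), the trsmU_T kernel solves the column panel against the tile's upper factor, and MatMul/Sub apply the Schur-complement update, with StridedSlice/ScatterNdUpdate moving tiles in and out of the full matrix. A NumPy sketch of the same schedule, with illustrative names only, assuming the matrix size is a multiple of the block size as in Expand:

    import numpy as np

    def unblocked_lu(a):
        """Unpivoted, in-place Doolittle LU of a small square float block."""
        for k in range(a.shape[0]):
            a[k + 1:, k] /= a[k, k]
            a[k + 1:, k + 1:] -= np.outer(a[k + 1:, k], a[k, k + 1:])
        return a

    def blocked_lu_reference(a, block=16):
        """Unpivoted blocked LU mirroring LU::Expand; overwrites `a` with L (unit diagonal) + U."""
        n = a.shape[0]
        for i in range(0, n, block):
            j = i + block
            unblocked_lu(a[i:j, i:j])                           # factor the diagonal tile (lu_decomp)
            if j < n:
                l11 = np.tril(a[i:j, i:j], -1) + np.eye(block)  # unit-diagonal lower factor of the tile
                u11 = np.triu(a[i:j, i:j])                      # upper factor of the tile
                a[i:j, j:] = np.linalg.solve(l11, a[i:j, j:])   # U12 = L11^{-1} A12   (trsmL_N_U step)
                a[j:, i:j] = a[j:, i:j] @ np.linalg.inv(u11)    # L21 = A21 U11^{-1}   (trsmU_T step)
                a[j:, j:] -= a[j:, i:j] @ a[i:j, j:]            # Schur complement     (MatMul + Sub)
        return a

Because no pivoting is performed, the expander also returns the trivial 0..n-1 pivot vector (range_cnode) and an identity matrix (eyes_cnode) alongside the packed factor.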
@@ -0,0 +1,59 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.scipy.ops import LU


class NetLU(nn.Cell):
    def __init__(self):
        super(NetLU, self).__init__()
        self.ops = LU()

    def construct(self, a):
        return self.ops(a)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lu():
    """
    Feature: LU op expanded through a custom op in graph mode.
    Description: test the LU op in graph mode.
    Expectation: the result matches the expected value.
    """
    num = 32
    one_tensor = np.zeros((num, num))
    for i in range(num):
        for j in range(num):
            one_tensor[i, j] = min(min(i, j), 8) + 1
    upper_matrix = np.triu(one_tensor).astype(np.float32)
    lower_matrix = np.tril(np.ones((num, num))).astype(np.float32)
    input1 = np.dot(lower_matrix, upper_matrix)
    expect = upper_matrix + lower_matrix - np.eye(num)
    real_input = Tensor(input1, mstype.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = NetLU()
    result = net(real_input)
    rtol = 0.001
    atol = 0.001
    assert np.allclose(result[0].asnumpy(), expect, rtol, atol, equal_nan=True)
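The expected value follows from the packed layout the expander produces: input1 = lower_matrix @ upper_matrix with a unit-diagonal lower factor, so the op's first output stores L strictly below the diagonal and U on and above it, which is exactly upper_matrix + lower_matrix - np.eye(num); with no pivoting, the remaining outputs are the trivial 0..num-1 pivots and an identity permutation (not checked here). A quick standalone check that unpacking expect reproduces the input:

    import numpy as np

    num = 32
    one_tensor = np.fromfunction(lambda i, j: np.minimum(np.minimum(i, j), 8) + 1.0, (num, num))
    upper_matrix = np.triu(one_tensor)
    lower_matrix = np.tril(np.ones((num, num)))
    expect = upper_matrix + lower_matrix - np.eye(num)

    l_factor = np.tril(expect, -1) + np.eye(num)   # unit-diagonal L recovered from the packed factor
    u_factor = np.triu(expect)                     # U recovered from the packed factor
    assert np.allclose(l_factor @ u_factor, lower_matrix @ upper_matrix)   # reproduces input1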