!30391 add julia cache and row-major api
Merge pull request !30391 from r1chardf1d0/master
commit 0ade79cb84
@@ -24,6 +24,7 @@
 #include <vector>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <thread>
 #include <condition_variable>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -272,17 +273,32 @@ class JuliaAPI {
   }

   bool RunJuliaKernel() {
-    // include julia file
-    JlEvalString("Base.include(Main, \"" + file_ + "\")");
-    RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
-    // using julia module
-    JlEvalString("using Main." + module_);
-    RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
-    jl_module_t *jmod = reinterpret_cast<jl_module_t *>(JlEvalString("Main." + module_));
-    RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
-    // get julia function from module
-    jl_function_t *jfunc = JlGetFunction(jmod, func_);
-    RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
+    if (!jl_file_caches_.count(file_)) {
+      // include julia file only once per file path
+      JlEvalString("Base.include(Main, \"" + file_ + "\")");
+      RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
+      jl_file_caches_.insert(file_);
+    }
+    jl_module_t *jmod = nullptr;
+    if (!jl_module_caches_.count(file_ + module_)) {
+      // resolve julia module once per (file, module)
+      JlEvalString("using Main." + module_);
+      RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
+      jmod = reinterpret_cast<jl_module_t *>(JlEvalString("Main." + module_));
+      RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
+      jl_module_caches_[file_ + module_] = jmod;
+    } else {
+      jmod = jl_module_caches_[file_ + module_];
+    }
+    jl_function_t *jfunc = nullptr;
+    if (!jl_function_caches_.count(file_ + module_ + func_)) {
+      // get julia function from module once per (file, module, func)
+      jfunc = JlGetFunction(jmod, func_);
+      RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
+      jl_function_caches_[file_ + module_ + func_] = jfunc;
+    } else {
+      jfunc = jl_function_caches_[file_ + module_ + func_];
+    }
     // convert kernel inputs to julia type
     std::vector<jl_value_t *> args(nparam_);
     for (int i = 0; i < nparam_; i++) {
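Note: the hunk above memoizes three Julia lookups, keyed by file path, file + module, and file + module + function, so repeated kernel launches skip re-including the file, re-resolving the module, and re-fetching the function. A minimal Python sketch of the same caching pattern (the jl_* helpers below are hypothetical stand-ins for the libjulia C API calls used in the C++ code):

file_cache = set()
module_cache = {}
function_cache = {}

def jl_eval_string(code):        # stand-in for jl_eval_string from libjulia
    return object()

def jl_get_function(mod, name):  # stand-in for jl_get_function from libjulia
    return object()

def run_julia_kernel(file, module, func):
    if file not in file_cache:                       # include the .jl file once
        jl_eval_string(f'Base.include(Main, "{file}")')
        file_cache.add(file)
    mkey = file + module
    if mkey not in module_cache:                     # resolve the module once
        jl_eval_string(f"using Main.{module}")
        module_cache[mkey] = jl_eval_string(f"Main.{module}")
    fkey = mkey + func
    if fkey not in function_cache:                   # look up the function once
        function_cache[fkey] = jl_get_function(module_cache[mkey], func)
    return function_cache[fkey]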
@@ -383,6 +399,11 @@ class JuliaAPI {
   std::vector<int64_t *> shapes_;
   std::vector<const char *> dtypes_;

+  // julia cache
+  std::unordered_set<std::string> jl_file_caches_;
+  std::unordered_map<std::string, jl_module_t *> jl_module_caches_;
+  std::unordered_map<std::string, jl_function_t *> jl_function_caches_;
+
   // about julia shared library
   void *handle_;
   jl_value_t *(*jl_eval_string_)(const char *);
@@ -48,61 +48,81 @@ class Custom(ops.PrimitiveWithInfer):
             2. A TBE operator implementation function.
             3. A pure python function

-            - str: If func is of str type, then str should be a path of binary file along with a function name.
-              This could only be used when func_type is "aot". Currently "aot" supports GPU/CPU(linux only) platform.
-              "aot" means ahead of time, in which case Custom directly launches user defined "xxx.so" file as an
-              operator. Users need to compile a handwriting "xxx.cu"/"xxx.cc" file into "xxx.so" ahead of time,
-              and offer the path of the file along with a function name.
-
-              - "xxx.so" file generation:
-
-                1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"), use nvcc command to compile
-                it.(ex. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu")
-
-                2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"), use g++/gcc command to compile
-                it.(ex. "g++ --shared -fPIC -o add.so add.cc")
-
-              - Define a "xxx.cc"/"xxx.cu" file:
-
-                "aot" is a cross-platform identity. The functions defined in "xxx.cc" or "xxx.cu" share the same args.
-                Typically, the function should be as:
-
-                .. code-block::
-
-                    int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
-                        void *stream, void *extra)
-
-                Parameters:
-
-                - nparam(int): total number of inputs plus outputs; suppose the operator has 2 inputs and 3 outputs,
-                  then nparam=5
-                - params(void \*\*): a pointer to the array of inputs and outputs' pointer; the pointer type of inputs
-                  and outputs is void \* ; suppose the operator has 2 inputs and 3 outputs, then the first input's
-                  pointer is params[0] and the second output's pointer is params[3]
-                - ndims(int \*): a pointer to the array of inputs and outputs' dimension num; suppose params[i] is a
-                  1024x1024 tensor and params[j] is a 77x83x4 tensor, then ndims[i]=2, ndims[j]=3.
-                - shapes(int64_t \*\*): a pointer to the array of inputs and outputs' shapes(int64_t \*); the ith
-                  input's jth dimension's size is shapes[i][j](0<=j<ndims[i]); suppose params[i] is a 2x3 tensor and
-                  params[j] is a 3x3x4 tensor, then shapes[i][0]=2, shapes[j][2]=4.
-                - dtypes(const char \*\*): a pointer to the array of inputs and outputs' types(const char \*);
-                  (ex. "float32", "float16", "float", "float64", "int", "int8", "int16", "int32", "int64", "uint",
-                  "uint8", "uint16", "uint32", "uint64", "bool")
-                - stream(void \*): stream pointer, only used in cuda file
-                - extra(void \*): used for further extension
-
-                Return Value(int):
-
-                - 0: MindSpore will continue to run if this aot kernel is successfully executed
-                - others: MindSpore will raise exception and exit
-
-                Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/
-
-              - Use it in Custom:
-
-                .. code-block::
-
-                    Custom(func="{dir_path}/{file_name}:{func_name}",...)
-                    (ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32))
+            - str: If func is of str type, then str should be a path of file along with a function name.
+              This could be used when func_type is "aot" or "julia".
+
+              1. for "aot":
+                 Currently "aot" supports GPU/CPU(linux only) platform.
+                 "aot" means ahead of time, in which case Custom directly launches a user-defined "xxx.so" file as
+                 an operator. Users need to compile a hand-written "xxx.cu"/"xxx.cc" file into "xxx.so" ahead of
+                 time, and offer the path of the file along with a function name.
+
+                 - "xxx.so" file generation:
+
+                   1) GPU Platform: Given a user-defined "xxx.cu" file (ex. "{path}/add.cu"), use the nvcc command
+                   to compile it. (ex. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu")
+
+                   2) CPU Platform: Given a user-defined "xxx.cc" file (ex. "{path}/add.cc"), use the g++/gcc
+                   command to compile it. (ex. "g++ --shared -fPIC -o add.so add.cc")
+
+                 - Define a "xxx.cc"/"xxx.cu" file:
+
+                   "aot" is a cross-platform identity. The functions defined in "xxx.cc" or "xxx.cu" share the same
+                   args. Typically, the function should be as:
+
+                   .. code-block::
+
+                       int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
+                           void *stream, void *extra)
+
+                   Parameters:
+
+                   - nparam(int): total number of inputs plus outputs; suppose the operator has 2 inputs and
+                     3 outputs, then nparam=5
+                   - params(void \*\*): a pointer to the array of inputs and outputs' pointers; the pointer type of
+                     inputs and outputs is void \* ; suppose the operator has 2 inputs and 3 outputs, then the first
+                     input's pointer is params[0] and the second output's pointer is params[3]
+                   - ndims(int \*): a pointer to the array of inputs and outputs' dimension numbers; suppose
+                     params[i] is a 1024x1024 tensor and params[j] is a 77x83x4 tensor, then ndims[i]=2, ndims[j]=3.
+                   - shapes(int64_t \*\*): a pointer to the array of inputs and outputs' shapes(int64_t \*); the ith
+                     input's jth dimension's size is shapes[i][j](0<=j<ndims[i]); suppose params[i] is a 2x3 tensor
+                     and params[j] is a 3x3x4 tensor, then shapes[i][0]=2, shapes[j][2]=4.
+                   - dtypes(const char \*\*): a pointer to the array of inputs and outputs' types(const char \*);
+                     (ex. "float32", "float16", "float", "float64", "int", "int8", "int16", "int32", "int64",
+                     "uint", "uint8", "uint16", "uint32", "uint64", "bool")
+                   - stream(void \*): stream pointer, only used in cuda file
+                   - extra(void \*): used for further extension
+
+                   Return Value(int):
+
+                   - 0: MindSpore will continue to run if this aot kernel is successfully executed
+                   - others: MindSpore will raise an exception and exit
+
+                   Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/
+
+                 - Use it in Custom:
+
+                   .. code-block::
+
+                       Custom(func="{dir_path}/{file_name}:{func_name}",...)
+                       (ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32,
+                       func_type="aot"))
+
+              2. for "julia":
+                 Currently "julia" supports CPU(linux only) platform.
+                 Julia uses a JIT compiler and provides a C API, so Custom can directly launch a user-defined
+                 "xxx.jl" file as an operator. Users need to write a "xxx.jl" file that includes modules and
+                 functions, and offer the path of the file along with a module name and function name.
+
+                 Examples: see details in tests/st/ops/graph_kernel/custom/julia_test_files/
+
+                 - Use it in Custom:
+
+                   .. code-block::
+
+                       Custom(func="{dir_path}/{file_name}:{module_name}:{func_name}",...)
+                       (ex. Custom(func="./add.jl:Add:add", out_shape=[1], out_dtype=mstype.float32,
+                       func_type="julia"))

         out_shape (Union[function, list, tuple]): The output shape infer function or the value of output shape of
             `func`.
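Note: the aot calling convention documented above can be exercised from plain Python to see how the argument arrays fit together; a minimal ctypes sketch, assuming an add.so built as described and exporting a hypothetical symbol CustomAdd with the documented signature:

import ctypes
import numpy as np

lib = ctypes.CDLL("./add.so")  # hypothetical library built from a hand-written add.cc
x = np.ones((2, 3), dtype=np.float32)
y = np.ones((2, 3), dtype=np.float32)
z = np.empty((2, 3), dtype=np.float32)
tensors = [x, y, z]

nparam = len(tensors)  # inputs plus outputs, as documented above
params = (ctypes.c_void_p * nparam)(*(t.ctypes.data for t in tensors))
ndims = (ctypes.c_int * nparam)(*(t.ndim for t in tensors))
shape_arrs = [(ctypes.c_int64 * t.ndim)(*t.shape) for t in tensors]
shapes = (ctypes.POINTER(ctypes.c_int64) * nparam)(*shape_arrs)
dtypes = (ctypes.c_char_p * nparam)(*([b"float32"] * nparam))

# stream and extra are NULL on CPU; a 0 return value means success
ret = lib.CustomAdd(nparam, params, ndims, shapes, dtypes, None, None)
assert ret == 0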
@@ -127,6 +147,7 @@ class Custom(ops.PrimitiveWithInfer):
             - "tbe": supports ["Ascend"].
             - "aot": supports ["GPU", "CPU"].
             - "pyfunc": supports ["CPU"].
+            - "julia": supports ["CPU"].

         bprop (function): The back propagation function of `func`. Default: None.
         reg_info (Union[str, dict, list, tuple]): Represents the registration information(reg info) of `func` with
@@ -273,6 +294,22 @@ class Custom(ops.PrimitiveWithInfer):
         ...         self.func = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")
         ...     def construct(self, x1, x2):
         ...         return self.func(x1, x2)
+        >>>
+        >>> # Example, func_type = "julia"
+        >>> # julia code:
+        >>> # add.jl
+        >>> # module Add
+        >>> # function add(x, y, z)
+        >>> #     z .= x + y
+        >>> #     return z
+        >>> # end
+        >>> # end
+        >>> class JULIASingleOutputNet(Cell):
+        ...     def __init__(self, out_shapes, out_types):
+        ...         super(JULIASingleOutputNet, self).__init__()
+        ...         self.program = ops.Custom("./add.jl:Add:add", out_shapes, out_types, "julia")
+        ...     def construct(self, x, y):
+        ...         return self.program(x, y)
         """

     registered_func = {}
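Note: a sketch of running the docstring example above end to end, assuming the add.jl from the comment sits in the working directory and JULIASingleOutputNet is defined as shown:

import numpy as np
from mindspore import context, Tensor
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x = Tensor(np.ones((2, 3), dtype=np.float32))
net = JULIASingleOutputNet(((2, 3),), (mstype.float32,))
print(net(x, x))  # elementwise x + y, computed by the julia kernel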
@@ -21,12 +21,11 @@ from mindspore import context, Tensor
 from mindspore.common import dtype as mstype
 from mindspore.nn import Cell
 import mindspore.ops as ops
-from mindspore.ops import DataType, CustomRegOp


-class JuliaSingleOutputNet(Cell):
+class JuliaTwoInputsNet(Cell):
     def __init__(self, func, out_shapes, out_types, reg=None):
-        super(JuliaSingleOutputNet, self).__init__()
+        super(JuliaTwoInputsNet, self).__init__()

         self.program = ops.Custom(func, out_shapes, out_types, "julia", reg_info=reg)

@@ -34,6 +33,16 @@ class JuliaSingleOutputNet(Cell):
         return self.program(x, y)


+class JuliaOneInputNet(Cell):
+    def __init__(self, func, out_shapes, out_types, reg=None):
+        super(JuliaOneInputNet, self).__init__()
+
+        self.program = ops.Custom(func, out_shapes, out_types, "julia", reg_info=reg)
+
+    def construct(self, x):
+        return self.program(x)
+
+
 def add(x, y):
     """
     function add for benchmark
@@ -48,26 +57,76 @@ def sub(x, y):
     return x - y


-def julia_single_output(func_name, bench, reg):
-    shape = (4, 5)
-    input_x = np.random.normal(0, 1, shape).astype(np.float32)
-    input_y = np.random.normal(0, 1, shape).astype(np.float32)
-    func_path = os.path.dirname(os.path.abspath(__file__)) + "/julia_test_files/"
-    try:
-        test = JuliaSingleOutputNet(func_path + func_name, (shape,), (mstype.float32,), reg)
-        output = test(Tensor(input_x), Tensor(input_y))[0]
-    except Exception as e:
-        raise e
-    assert np.allclose(bench(input_x, input_y), output.asnumpy(), 0.001, 0.001)
-
-
-cpu_info = CustomRegOp() \
-    .input(0, "x1") \
-    .input(1, "x2") \
-    .output(0, "y") \
-    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
-    .target("CPU") \
-    .get_op_info()
+def matmul(x, y):
+    """
+    function matmul for benchmark
+    """
+    return np.matmul(x, y)
+
+
+def reducesum(x, axis=0, keepdims=True):
+    return np.sum(x, axis=axis, keepdims=keepdims)
+
+
+def multiout(a, b):
+    return a + b, a - b
+
+
+def julia_elemwise_test(func_name, bench):
+    shape = (4, 5)
+    input_x = np.random.normal(0, 1, shape).astype(np.float32)
+    input_y = np.random.normal(0, 1, shape).astype(np.float32)
+    func_path = os.path.dirname(os.path.abspath(__file__)) + "/julia_test_files/"
+    try:
+        test = JuliaTwoInputsNet(func_path + func_name, (shape,), (mstype.float32,))
+        output = test(Tensor(input_x), Tensor(input_y))[0]
+    except Exception as e:
+        raise e
+    assert np.allclose(bench(input_x, input_y), output.asnumpy(), 0.001, 0.001)
+
+
+def julia_matmul_test(func_name, bench):
+    shape1 = (2, 3)
+    shape2 = (3, 4)
+    shape3 = (2, 4)
+    input_x = np.random.normal(0, 1, shape1).astype(np.float32)
+    input_y = np.random.normal(0, 1, shape2).astype(np.float32)
+    func_path = os.path.dirname(os.path.abspath(__file__)) + "/julia_test_files/"
+    try:
+        test = JuliaTwoInputsNet(func_path + func_name, (shape3,), (mstype.float32,))
+        output = test(Tensor(input_x), Tensor(input_y))[0]
+    except Exception as e:
+        raise e
+    assert np.allclose(bench(input_x, input_y), output.asnumpy(), 0.001, 0.001)
+
+
+def julia_reducesum_test(func_name, bench):
+    shape1 = (2, 3, 4)
+    input_x = np.random.normal(0, 1, shape1).astype(np.float32)
+    expect = bench(input_x, 1)
+    func_path = os.path.dirname(os.path.abspath(__file__)) + "/julia_test_files/"
+    try:
+        test = JuliaOneInputNet(func_path + func_name, (expect.shape,), (mstype.float32,))
+        output = test(Tensor(input_x))[0]
+    except Exception as e:
+        raise e
+    assert np.allclose(expect, output.asnumpy(), 0.001, 0.001)
+
+
+def julia_multiout_test(func_name, bench):
+    shape = (4, 5)
+    input_x = np.random.normal(0, 1, shape).astype(np.float32)
+    input_y = np.random.normal(0, 1, shape).astype(np.float32)
+    func_path = os.path.dirname(os.path.abspath(__file__)) + "/julia_test_files/"
+    try:
+        test = JuliaTwoInputsNet(func_path + func_name, (shape, shape,), (mstype.float32, mstype.float32,))
+        output1 = test(Tensor(input_x), Tensor(input_y))[0]
+        output2 = test(Tensor(input_x), Tensor(input_y))[1]
+    except Exception as e:
+        raise e
+    expect1, expect2 = bench(input_x, input_y)
+    assert np.allclose(expect1, output1.asnumpy(), 0.001, 0.001)
+    assert np.allclose(expect2, output2.asnumpy(), 0.001, 0.001)


 @pytest.mark.level2
@@ -84,7 +143,7 @@ def test_julia_single_output_cpu_add():
         pass
     else:
         context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        julia_single_output("add.jl:Add:foo!", add, cpu_info)
+        julia_elemwise_test("add.jl:Add:foo!", add)


 @pytest.mark.level2
@@ -101,4 +160,55 @@ def test_julia_single_output_cpu_sub():
         pass
     else:
         context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        julia_single_output("sub.jl:Sub:foo!", sub, cpu_info)
+        julia_elemwise_test("sub.jl:Sub:foo!", sub)
+
+
+@pytest.mark.level2
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_julia_single_output_cpu_matmul():
+    """
+    Feature: custom julia operator, multiple inputs, single output, CPU, GRAPH_MODE
+    Description: pre-write xxx.jl, custom operator launches xxx.jl
+    Expectation: nn result matches numpy result
+    """
+    system = platform.system()
+    if system != 'Linux':
+        pass
+    else:
+        context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+        julia_matmul_test("matmul.jl:Matmul:foo!", matmul)
+
+
+@pytest.mark.level2
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_julia_single_output_cpu_reducesum():
+    """
+    Feature: custom julia operator, single input, single output, CPU, GRAPH_MODE
+    Description: pre-write xxx.jl, custom operator launches xxx.jl
+    Expectation: nn result matches numpy result
+    """
+    system = platform.system()
+    if system != 'Linux':
+        pass
+    else:
+        context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+        julia_reducesum_test("reducesum.jl:ReduceSum:foo!", reducesum)
+
+
+@pytest.mark.level2
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_julia_multi_output_cpu():
+    """
+    Feature: custom julia operator, multiple inputs, multiple outputs, CPU, GRAPH_MODE
+    Description: pre-write xxx.jl, custom operator launches xxx.jl
+    Expectation: nn result matches numpy result
+    """
+    system = platform.system()
+    if system != 'Linux':
+        pass
+    else:
+        context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+        julia_multiout_test("multi_output.jl:MultiOutput:foo!", multiout)
@@ -0,0 +1,59 @@
+# matmul.jl
+module Matmul
+export foo!
+include("row_major.jl")
+
+function gemm(x, y, z)
+    @inbounds @fastmath for m ∈ axes(x, 1), n ∈ axes(y, 2)
+        zmn = zero(eltype(z))
+        for k ∈ axes(x, 2)
+            zmn += x[m, k] * y[k, n]
+        end
+        z[m, n] = zmn
+    end
+    return z
+end
+
+# z is the output; use .= to assign to it in place
+# julia arrays are column-major, numpy arrays are row-major
+# the user should change julia's or numpy's layout to keep the same behavior
+#= EXAMPLE
+A[2,3] B[3,4] C[2,4]
+NUMPY:
+[[1, 2, 3]   [[1, 2, 3, 4]   [[38, 44,  50,  56]
+ [4, 5, 6]]   [5, 6, 7, 8]    [83, 98, 113, 128]]
+              [9,10,11,12]]
+JULIA:
+inputs read numpy data from memory:
+[[1, 3, 5]   [[1, 4, 7,10]
+ [2, 4, 6]]   [2, 5, 8,11]
+              [3, 6, 9,12]]
+inputs after reshape(reverse(shape)):
+[[1, 4]      [[1, 5, 9]
+ [2, 5]       [2, 6,10]
+ [3, 6]]      [3, 7,11]
+              [4, 8,12]]
+inputs after transpose/permutedims:
+[[1, 2, 3]   [[1, 2, 3, 4]   [[38, 44,  50,  56]
+ [4, 5, 6]]   [5, 6, 7, 8]    [83, 98, 113, 128]]
+              [9,10,11,12]]
+output after transpose/permutedims:
+[[38, 83]
+ [44, 98]
+ [50,113]
+ [56,128]]
+output after reshape:
+[[38, 50, 83, 113]
+ [44, 56, 98, 128]]
+output read numpy data from memory:
+[[38, 44,  50,  56]
+ [83, 98, 113, 128]]
+=#
+function foo!(x, y, z)
+    x = change_input_to_row_major(x)
+    y = change_input_to_row_major(y)
+    z .= gemm(x, y, z)
+    z .= change_output_to_row_major(z)
+end
+
+end
@@ -0,0 +1,32 @@
+# matmul_loop_vectorization.jl
+module Matmul
+export foo!
+include("row_major.jl")
+
+# if the LoopVectorization package is not installed, add it as below:
+# import Pkg
+# Pkg.add("LoopVectorization")
+using LoopVectorization
+
+function gemmavx(x, y, z)
+    @turbo for m ∈ axes(x, 1), n ∈ axes(y, 2)
+        zmn = zero(eltype(z))
+        for k ∈ axes(x, 2)
+            zmn += x[m, k] * y[k, n]
+        end
+        z[m, n] = zmn
+    end
+    return z
+end
+
+# z is the output; use .= to assign to it in place
+# julia arrays are column-major, numpy arrays are row-major
+# the user should transpose julia's or numpy's array to keep the same behavior
+function foo!(x, y, z)
+    x = change_input_to_row_major(x)
+    y = change_input_to_row_major(y)
+    z .= gemmavx(x, y, z)
+    z .= change_output_to_row_major(z)
+end
+
+end
@@ -0,0 +1,11 @@
+# multi_output.jl
+module MultiOutput
+export foo!
+
+# inputs: a, b; outputs: c, d
+function foo!(a, b, c, d)
+    c .= a + b
+    d .= a - b
+end
+
+end
@@ -0,0 +1,13 @@
+# reducesum.jl
+module ReduceSum
+export foo!
+include("row_major.jl")
+
+function foo!(x, y)
+    x = change_input_to_row_major(x)
+    # julia's dims=2 (1-based) corresponds to numpy's axis=1 (0-based)
+    y .= sum(x, dims=2)
+    y .= change_output_to_row_major(y)
+end
+
+end
@@ -0,0 +1,10 @@
+export change_input_to_row_major
+export change_output_to_row_major
+
+function change_input_to_row_major(x)
+    return permutedims(reshape(x, reverse(size(x))), length(size(x)):-1:1)
+end
+
+function change_output_to_row_major(x)
+    return reshape(permutedims(x, length(size(x)):-1:1), size(x))
+end
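Note: a small NumPy check of what these two helpers achieve, mirroring the EXAMPLE comment in matmul.jl (a 2x3 row-major buffer handed to column-major Julia):

import numpy as np

a = np.arange(1, 7, dtype=np.float32).reshape(2, 3)  # [[1, 2, 3], [4, 5, 6]], row-major
buf = a.ravel(order='C')                             # the raw memory Julia receives

# Julia's reshape(x, reverse(size(x))) reads the buffer column-major as (3, 2):
as_rev = buf.reshape((3, 2), order='F')              # [[1, 4], [2, 5], [3, 6]]
# permutedims(..., ndims:-1:1) then restores the intended row-major view:
restored = as_rev.T                                  # [[1, 2, 3], [4, 5, 6]]
assert np.array_equal(restored, a)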