forked from jittor/jittor

fix doc

parent 93ca5e9525
commit 5bcfe0db60
@@ -15,4 +15,4 @@ echo "[jittor path] $jittor_path"

export PYTHONPATH=$jittor_path/python
cd $bpath
sphinx-autobuild -b html source build
sphinx-autobuild -b html source build -H 0.0.0.0 -p 8890
@@ -63,7 +63,9 @@ single_log_capture = None

class log_capture_scope(_call_no_record_scope):
"""log capture scope
example:

example::

with jt.log_capture_scope(log_v=0) as logs:
LOG.v("...")
print(logs)
@@ -95,7 +97,9 @@ class log_capture_scope(_call_no_record_scope):

class profile_scope(_call_no_record_scope):
""" profile scope
example:

example::

with jt.profile_scope() as report:
......
print(report)
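The profile_scope example above elides its body with `......`; a more concrete sketch is given below. It only uses ops that appear elsewhere in this commit (jt.random, elementwise multiply, sum, .data); the exact content of the report is not specified here.

```
import jittor as jt

# hedged sketch: run a couple of ops inside a profile scope, then inspect the report
with jt.profile_scope() as report:
    a = jt.random([1000, 1000])
    b = (a * a).sum()
    print(b.data)    # forces execution so the profiler has something to record
print(report)        # report is filled in after the scope exits
```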
@@ -182,7 +186,7 @@ def norm(x, k, dim):
if k==1:
return x.abs().sum(dim)
if k==2:
return x.sqr().sum(dim).sqrt()
return (x**2).sum(dim).sqrt()
Var.norm = norm

origin_reshape = reshape
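A small usage sketch of the norm helper bound to Var above; the values are chosen so the expected results are easy to check by hand. Treat it as an illustration, not repository test code.

```
import jittor as jt

x = jt.array([[3.0, 4.0], [6.0, 8.0]])
print(x.norm(1, 1).data)   # L1 norm along dim 1 -> [ 7. 14.]
print(x.norm(2, 1).data)   # L2 norm along dim 1 -> [ 5. 10.]
```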
@@ -474,7 +478,9 @@ def make_module(func, exec_n_args=1):

def dirty_fix_pytorch_runtime_error():
''' This function should be called before pytorch.
Example:

Example::

import jittor as jt
jt.dirty_fix_pytorch_runtime_error()
import torch
@@ -37,7 +37,8 @@ class Dataset(object):
'''
base class for reading data

Example:
Example::

class YourDataset(Dataset):
def __init__(self):
super().__init__()
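The YourDataset snippet in the docstring stops at __init__; a fuller, hypothetical sketch is shown below, consistent with the attributes documented in the next hunk. The import path and the in-memory toy data are assumptions.

```
import numpy as np
from jittor.dataset.dataset import Dataset   # import path assumed

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        # toy in-memory data; a real dataset would read files here
        self.x = np.random.rand(256, 3, 32, 32).astype("float32")
        self.y = np.random.randint(0, 10, 256)
        self.set_attrs(total_len=256, batch_size=16, shuffle=True)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

for images, labels in YourDataset():
    print(images.shape, labels.shape)
    break
```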
@@ -75,17 +76,16 @@ class Dataset(object):
return (self.total_len-1) // self.batch_size + 1

def set_attrs(self, **kw):
'''set attributes of dataset, equivalent to setattr
'''set attributes of dataset, equivalent to set_attr

Attrs:
batch_size(int): batch size, default 16.
total_len(int): total length.
shuffle(bool): shuffle at each epoch, default False.
drop_last(bool): if true, the last batch of dataset
might be smaller than batch_size, default True.
num_workers: number of workers for loading data
buffer_size: buffer size for each worker in bytes,
default(512MB).

* batch_size(int): batch size, default 16.
* total_len(int): total length.
* shuffle(bool): shuffle at each epoch, default False.
* drop_last(bool): if true, the last batch of dataset might be smaller than batch_size, default True.
* num_workers: number of workers for loading data
* buffer_size: buffer size for each worker in bytes, default(512MB).
'''
for k,v in kw.items():
assert hasattr(self, k), k
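A short usage sketch of set_attrs with the attributes listed above; the dataset instance is the hypothetical YourDataset from the earlier sketch.

```
dataset = YourDataset()
dataset.set_attrs(batch_size=32, shuffle=True, num_workers=4)
```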
@@ -287,19 +287,22 @@ class Dataset(object):
class ImageFolder(Dataset):
"""An image classification dataset that loads images and labels from a directory:

root/label1/img1.png
root/label1/img2.png
...
root/label2/img1.png
root/label2/img2.png
...
Args:
root(string): Root directory path.
* root/label1/img1.png
* root/label1/img2.png
* ...
* root/label2/img1.png
* root/label2/img2.png
* ...

Attributes:
classes(list): List of the class names.
class_to_idx(dict): map from class_name to class_index.
imgs(list): List of (image_path, class_index) tuples
Args:

* root(string): Root directory path.

Attributes:

* classes(list): List of the class names.
* class_to_idx(dict): map from class_name to class_index.
* imgs(list): List of (image_path, class_index) tuples
"""
def __init__(self, root, transform=None):
# import ipdb; ipdb.set_trace()
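A usage sketch for ImageFolder following the directory layout documented above; the data path is a placeholder and the import path is an assumption.

```
from jittor.dataset import ImageFolder   # import path assumed

dataset = ImageFolder("./data/train")    # ./data/train/<label>/<image>.png
dataset.set_attrs(batch_size=16, shuffle=True)
print(dataset.classes, dataset.class_to_idx)
for images, labels in dataset:
    print(images.shape, labels.shape)
    break
```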
@@ -15,11 +15,11 @@ import numpy as np

class Optimizer(object):
""" Basic class of Optimizer.
Example:
```
optimizer = nn.SGD(model.parameters(), lr)
optimizer.step(loss)
```

Example::

optimizer = nn.SGD(model.parameters(), lr)
optimizer.step(loss)
"""
def __init__(self, params, lr, param_sync_iter=10000):
self.param_groups = []
@@ -35,15 +35,14 @@ class Optimizer(object):
self.n_step = 0

def pre_step(self, loss):
""" something should be done before step,
such as calc gradients, mpi sync, and so on.
Example:
```
class MyOptimizer(Optimizer):
def step(self, loss):
self.post_step(loss)
...
```
""" something should be done before step, such as calc gradients, mpi sync, and so on.

Example::

class MyOptimizer(Optimizer):
def step(self, loss):
self.post_step(loss)
...
"""
# clean prev grads
params = []
@@ -92,11 +91,11 @@ class Optimizer(object):

class SGD(Optimizer):
""" SGD Optimizer.
Example:
```
optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
optimizer.step(loss)
```

Example::

optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
optimizer.step(loss)
"""
def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
super().__init__(params, lr)
@@ -134,11 +133,11 @@ class SGD(Optimizer):

class Adam(Optimizer):
""" Adam Optimizer.
Example:
```
optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
optimizer.step(loss)
```

Example::

optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
optimizer.step(loss)
"""
def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
super().__init__(params, lr)
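A minimal end-to-end sketch tying the SGD/Adam examples above together; the model, data, and learning rate are placeholders, and only optimizer.step(loss) is taken from the docstrings.

```
import jittor as jt
from jittor import nn

model = nn.Linear(10, 1)                     # placeholder model
optimizer = nn.SGD(model.parameters(), lr=0.1, momentum=0.9)

x = jt.random([4, 10])
y = jt.random([4, 1])
for _ in range(3):
    loss = ((model(x) - y) ** 2).mean()
    optimizer.step(loss)                     # jittor-style step: pass the loss directly
    print(loss.data)
```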
@@ -27,10 +27,11 @@ class RandomCropAndResize:
"""Random crop and resize the given PIL Image to given size.

Args:
size(int or tuple): width and height of the output image
scale(tuple): range of scale ratio of the area
ratio(tuple): range of aspect ratio
interpolation: Default: PIL.Image.BILINEAR

* size(int or tuple): width and height of the output image
* scale(tuple): range of scale ratio of the area
* ratio(tuple): range of aspect ratio
* interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, scale:tuple=(0.08, 1.0), ratio:tuple=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
if isinstance(size, int):
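A usage sketch for RandomCropAndResize; the transform module import path, the image file, and the assumption that the transform is applied by calling it on a PIL Image are all assumptions, not taken from this diff.

```
from PIL import Image
from jittor import transform              # import path assumed

crop = transform.RandomCropAndResize(224, scale=(0.5, 1.0))
img = Image.open("example.png")           # placeholder input image
out = crop(img)                           # assumed callable; expected to return a 224x224 PIL Image
print(out.size)
```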
@@ -32,28 +32,28 @@ struct ArgsortOp : Op {

compare(y[0], y[1]) && compare(y[1], y[2]) && ...

@param[in] x input var for sort
* [in] x: input var for sort

@param[in] dim sort along which dim
* [in] dim: sort along which dim

@param[in] dtype type of return indexes
* [in] dtype: type of return indexes

@param[in] key code for sorted key
* [in] key: code for sorted key

@param[in] compare code for compare
* [in] compare: code for compare

@param[out] index index has the same size as the sorted dim
* [out] index: index has the same size as the sorted dim


Example
```
jt.sort([11,13,12])
# return [0,2,1]
jt.sort([11,13,12], key='-@x(i)')
# return [1,2,0]
jt.sort([11,13,12], key='@x(i)<@x(j)')
# return [0,2,1]
```
Example::

jt.sort([11,13,12])
# return [0,2,1]
jt.sort([11,13,12], key='-@x(i)')
# return [1,2,0]
jt.sort([11,13,12], key='@x(i)<@x(j)')
# return [0,2,1]

*/
// @attrs(multiple_outputs)
ArgsortOp(Var* x, int dim=-1, bool descending=false, NanoString dtype=ns_int32);
@@ -18,7 +18,7 @@ struct CandidateOp : Op {
/**
Candidate Operator performs an indirect candidate filter given a fail condition.

x is input, y is output index, satisfy:
x is input, y is output index, satisfy::

not fail_cond(y[0], y[1]) and
not fail_cond(y[0], y[2]) and not fail_cond(y[1], y[2]) and
@@ -27,35 +27,33 @@ struct CandidateOp : Op {

Where m is number of selected candidates.

Pseudo code:
```
y = []
for i in range(n):
pass = True
for j in y:
if (@fail_cond):
pass = false
break
if (pass):
y.append(i)
return y
```
Pseudo code::

y = []
for i in range(n):
pass = True
for j in y:
if (@fail_cond):
pass = false
break
if (pass):
y.append(i)
return y

@param[in] x input var for filter
* [in] x: input var for filter

@param[in] fail_cond code for fail condition
* [in] fail_cond: code for fail condition

@param[in] dtype type of return indexes
* [in] dtype: type of return indexes

@param[out] index .
* [out] index: .

Example
```
jt.candidate(jt.random(100,2), '(@x(j,0)>@x(i,0))or(@x(j,1)>@x(i,1))')
# return y satisfy:
# x[y[0], 0] <= x[y[1], 0] and x[y[1], 0] <= x[y[2], 0] and ... and x[y[m-2], 0] <= x[y[m-1], 0] and
# x[y[0], 1] <= x[y[1], 1] and x[y[1], 1] <= x[y[2], 1] and ... and x[y[m-2], 1] <= x[y[m-1], 1]
```
Example::

jt.candidate(jt.random(100,2), '(@x(j,0)>@x(i,0))or(@x(j,1)>@x(i,1))')
# return y satisfy:
# x[y[0], 0] <= x[y[1], 0] and x[y[1], 0] <= x[y[2], 0] and ... and x[y[m-2], 0] <= x[y[m-1], 0] and
# x[y[0], 1] <= x[y[1], 1] and x[y[1], 1] <= x[y[2], 1] and ... and x[y[m-2], 1] <= x[y[m-1], 1]
*/
CandidateOp(Var* x, string&& fail_cond, NanoString dtype=ns_int32);
void infer_shape() override;
@@ -22,20 +22,20 @@ struct CodeOp : Op {

----------------

@param[in] shape the output shape, an integer array
* [in] shape: the output shape, an integer array

@param[in] dtype the output data type
* [in] dtype: the output data type

@param[in] inputs A list of input jittor Vars
* [in] inputs: A list of input jittor Vars

@param[in] cpu_src cpu source code string, builtin values:
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
* [in] cpu_src: cpu source code string, builtin values:

* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)

@param[in] cpu_grad_src A list of string,
cpu source code string for gradient, represents the gradient
for each input, builtin values:
* [in] cpu_grad_src: A list of string, cpu source code string for gradient, represents the gradient for each input, builtin values:

* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
@@ -43,171 +43,164 @@ struct CodeOp : Op {
* pout, pout_shape{y}, pout_stride{y}, pout_type, pout_p, @pout(...)
* dout, dout_shape{y}, dout_stride{y}, dout_type, dout_p, @dout(...)

@param[in] cpu_header cpu header code string.
* [in] cpu_header: cpu header code string.

@param[in] cuda_src cuda source code string.
* [in] cuda_src: cuda source code string.

@param[in] cuda_grad_src A list of string.
* [in] cuda_grad_src: A list of string.

@param[in] cuda_header cuda header code string.
* [in] cuda_header: cuda header code string.

----------------

Example-1:
Example-1::

```
a = jt.random([10])
b = jt.code(a.shape, "float32", [a],
cpu_src='''
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in0(i)*2;
''',
cpu_grad_src = ['''
for (int i=0; i<in0_shape0; i++)
@out(i) = @dout(i)*@in0(i)*4;
'''])
```
a = jt.random([10])
b = jt.code(a.shape, "float32", [a],
cpu_src='''
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in0(i)*2;
''',
cpu_grad_src = ['''
for (int i=0; i<in0_shape0; i++)
@out(i) = @dout(i)*@in0(i)*4;
'''])

Example-2:
```
a = jt.array([3,2,1])
b = jt.code(a.shape, a.dtype, [a],
cpu_header="""
#include <algorithm>
@alias(a, in0)
@alias(b, out)
""",
cpu_src="""
for (int i=0; i<a_shape0; i++)
@b(i) = @a(i);
std::sort(&@b(0), &@b(in0_shape0));
"""
)
assert (b.data==[1,2,3]).all()
```
Example-2::

Example-3:
This example shows how to set multiple outputs in code op.
```
a = jt.array([3,2,1])
b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],
cpu_header="""
#include <iostream>
using namespace std;
""",
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
@b(0) = @c(0) = @a(0);
for (int i=0; i<a_shape0; i++) {
@b(0) = std::min(@b(0), @a(i));
@c(0) = std::max(@c(0), @a(i));
}
cout << "min:" << @b(0) << " max:" << @c(0) << endl;
"""
)
assert b.data == 1, b
assert c.data == 3, c
```
a = jt.array([3,2,1])
b = jt.code(a.shape, a.dtype, [a],
cpu_header="""
#include <algorithm>
@alias(a, in0)
@alias(b, out)
""",
cpu_src="""
for (int i=0; i<a_shape0; i++)
@b(i) = @a(i);
std::sort(&@b(0), &@b(in0_shape0));
"""
)
assert (b.data==[1,2,3]).all()

Example-4:
This example shows how to use dynamic shape of jittor variables.
```
a = jt.array([5,-4,3,-2,1])
Example-3::

#This example shows how to set multiple outputs in code op.
a = jt.array([3,2,1])
b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],
cpu_header="""
#include <iostream>
using namespace std;
""",
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
@b(0) = @c(0) = @a(0);
for (int i=0; i<a_shape0; i++) {
@b(0) = std::min(@b(0), @a(i));
@c(0) = std::max(@c(0), @a(i));
}
cout << "min:" << @b(0) << " max:" << @c(0) << endl;
"""
)
assert b.data == 1, b
assert c.data == 3, c

Example-4::

#This example shows how to use dynamic shape of jittor variables.
a = jt.array([5,-4,3,-2,1])

# negative shape for max size of varying dimension
b,c = jt.code([(-5,), (-5,)], [a.dtype, a.dtype], [a],
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
int num_b=0, num_c=0;
for (int i=0; i<a_shape0; i++) {
if (@a(i)>0)
@b(num_b++) = @a(i);
else
@c(num_c++) = @a(i);
}
b->set_shape({num_b});
c->set_shape({num_c});
"""
)
assert (b.data == [5,3,1]).all()
assert (c.data == [-4,-2]).all()


CUDA Example-1::

#This example shows how to use CUDA in code op.
a = jt.random([100000])
b = jt.random([100000])
c = jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
'''])

CUDA Example-2::

# negative shape for max size of varying dimension
b,c = jt.code([(-5,), (-5,)], [a.dtype, a.dtype], [a],
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
int num_b=0, num_c=0;
for (int i=0; i<a_shape0; i++) {
if (@a(i)>0)
@b(num_b++) = @a(i);
else
@c(num_c++) = @a(i);
}
b->set_shape({num_b});
c->set_shape({num_c});
"""
)
assert (b.data == [5,3,1]).all()
assert (c.data == [-4,-2]).all()
```


CUDA Example-1:
This example shows how to use CUDA in code op.
```
a = jt.random([100000])
b = jt.random([100000])
c = jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
'''])
```

CUDA Example-2:
This example shows how to use multi-dimensional data with CUDA.
```
a = jt.random((100,100))
b = jt.random((100,100))
c = jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel2<<<32, 32>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel3<<<32, 32>>>(@ARGS);
'''])
```
#This example shows how to use multi-dimensional data with CUDA.
a = jt.random((100,100))
b = jt.random((100,100))
c = jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel2<<<32, 32>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel3<<<32, 32>>>(@ARGS);
'''])
*/
CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
@@ -16,17 +16,16 @@ struct ConcatOp : Op {
/**
Concat Operator can concat a list of jt Var at a specific dimension.

@param[in] x input var list for concat
* [in] x: input var list for concat

@param[in] dim concat which dim
* [in] dim: concat which dim

@param[out] out concat result
* [out] out: concat result

Example
```
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
```
Example::

jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
*/
ConcatOp(vector<Var*>&& x, int dim=0);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
@@ -15,29 +15,26 @@ struct IndexOp : Op {
/**
Index Operator generates the index of a shape.

It performs equivalent Python-pseudo implementation below:
It performs equivalent Python-pseudo implementation below::

```
n = len(shape)-1
x = np.zeros(shape, dtype)
for i0 in range(shape[0]): # 1-st loop
for i1 in range(shape[1]): # 2-nd loop
...... # many loops
for in in range(shape[n]) # n+1 -th loop
x[i0,i1,...,in] = i@dim
```
n = len(shape)-1
x = np.zeros(shape, dtype)
for i0 in range(shape[0]): # 1-st loop
for i1 in range(shape[1]): # 2-nd loop
...... # many loops
for in in range(shape[n]) # n+1 -th loop
x[i0,i1,...,in] = i@dim

@param[in] shape the output shape, an integer array
@param[in] dim the dim of the index.
@param[in] dtype the data type string, default int32
* [in] shape: the output shape, an integer array
* [in] dim: the dim of the index.
* [in] dtype: the data type string, default int32

Example
```
print(jt.index([2,2], 0)())
# output: [[0,0],[1,1]]
print(jt.index([2,2], 1)())
# output: [[0,1],[0,1]]
```
Example::

print(jt.index([2,2], 0)())
# output: [[0,0],[1,1]]
print(jt.index([2,2], 1)())
# output: [[0,1],[0,1]]
*/
IndexOp(NanoVector shape, int64 dim, NanoString dtype=ns_int32);
// @attrs(multiple_outputs)
@@ -18,77 +18,74 @@ struct ReindexOp : Op {
vector<Var*> extras;
/**
Reindex Operator is a one-to-many map operator.
It performs equivalent Python-pseudo implementation below:
```
# input is x, output is y
n = len(shape)-1
m = len(x.shape)-1
k = len(overflow_conditions)-1
y = np.zeros(shape, x.dtype)
for i0 in range(shape[0]): # 1-st loop
for i1 in range(shape[1]): # 2-nd loop
...... # many loops
for in in range(shape[n]) # n+1 -th loop
if is_overflow(i0,i1,...,in):
y[i0,i1,...,in] = overflow_value
else:
# indexes[i] is a c++ style integer expression consisting of i0,i1,...,in
y[i0,i1,...,in] = x[indexes[0],indexes[1],...,indexes[m]]
It performs equivalent Python-pseudo implementation below::

# is_overflow is defined as following
def is_overflow(i0,i1,...,in):
return (
indexes[0] < 0 || indexes[0] >= x.shape[0] ||
indexes[1] < 0 || indexes[1] >= x.shape[1] ||
......
indexes[m] < 0 || indexes[m] >= x.shape[m] ||
# input is x, output is y
n = len(shape)-1
m = len(x.shape)-1
k = len(overflow_conditions)-1
y = np.zeros(shape, x.dtype)
for i0 in range(shape[0]): # 1-st loop
for i1 in range(shape[1]): # 2-nd loop
...... # many loops
for in in range(shape[n]) # n+1 -th loop
if is_overflow(i0,i1,...,in):
y[i0,i1,...,in] = overflow_value
else:
# indexes[i] is a c++ style integer expression consisting of i0,i1,...,in
y[i0,i1,...,in] = x[indexes[0],indexes[1],...,indexes[m]]

# overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in
overflow_conditions[0] ||
overflow_conditions[1] ||
......
overflow_conditions[k]
)
```
# is_overflow is defined as following
def is_overflow(i0,i1,...,in):
return (
indexes[0] < 0 || indexes[0] >= x.shape[0] ||
indexes[1] < 0 || indexes[1] >= x.shape[1] ||
......
indexes[m] < 0 || indexes[m] >= x.shape[m] ||

# overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in
overflow_conditions[0] ||
overflow_conditions[1] ||
......
overflow_conditions[k]
)
----------------
@param[in] x An input jittor Var
* [in] x: An input jittor Var

@param[in] shape the output shape, an integer array
* [in] shape: the output shape, an integer array

@param[in] indexes array of c++ style integer expression, its length should be
the same as the number of dimensions of x, some builtin variables it can use
are: XDIM, xshape0, ..., xshapen, xstride0, ..., xstriden
* [in] indexes: array of c++ style integer expression, its length should be the same as the number of dimensions of x, some builtin variables it can use are::

XDIM, xshape0, ..., xshapen, xstride0, ..., xstriden
YDIM, yshape0, ..., yshapem, ystride0, ..., ystridem
i0, i1, ..., in
@e0(...), @e1(...) for extras input index
e0p, e1p , ... for extras input pointer

@param[in] overflow_value overflow value
* [in] overflow_value: overflow value

@param[in] overflow_conditions array of c++ style boolean expression, its length
can vary. the builtin variables it can use are the same as indexes
* [in] overflow_conditions: array of c++ style boolean expression, its length can vary. the builtin variables it can use are the same as indexes

@param[in] extras: extra var used for index
* [in] extras: extra var used for index

----------------
Example
Convolution implemented by reindex operation:
```
def conv(x, w):
N,H,W,C = x.shape
Kh, Kw, _C, Kc = w.shape
assert C==_C
xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
'i0', # Nid
'i1+i3', # Hid+Khid
'i2+i4', # Wid+KWid
'i5', # Cid
])
ww = w.broadcast_var(xx)
yy = xx*ww
y = yy.sum([3,4,5]) # Kh, Kw, C
return y, yy
```
Convolution implemented by reindex operation::

def conv(x, w):
N,H,W,C = x.shape
Kh, Kw, _C, Kc = w.shape
assert C==_C
xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
'i0', # Nid
'i1+i3', # Hid+Khid
'i2+i4', # Wid+KWid
'i5', # Cid
])
ww = w.broadcast_var(xx)
yy = xx*ww
y = yy.sum([3,4,5]) # Kh, Kw, C
return y, yy
*/
ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64 overflow_value=0, vector<string>&& overflow_conditions={}, vector<Var*>&& extras={});
/** Alias x.reindex([i,j,k]) ->
@@ -17,73 +17,70 @@ struct ReindexReduceOp : Op {
vector<Var*> extras;
/**
Reindex Reduce Operator is a many-to-one map operator.
It performs equivalent Python-pseudo implementation below:
```
# input is y, output is x
n = len(y.shape)-1
m = len(shape)-1
k = len(overflow_conditions)-1
x = np.zeros(shape, y.dtype)
x[:] = initial_value(op)
for i0 in range(y.shape[0]): # 1-st loop
for i1 in range(y.shape[1]): # 2-nd loop
...... # many loops
for in in range(y.shape[n]) # n+1 -th loop
# indexes[i] is a c++ style integer expression consisting of i0,i1,...,in
xi0,xi1,...,xim = indexes[0],indexes[1],...,indexes[m]
if not is_overflow(xi0,xi1,...,xim):
x[xi0,xi1,...,xim] = op(x[xi0,xi1,...,xim], y[i0,i1,...,in])
It performs equivalent Python-pseudo implementation below::

# is_overflow is defined as following
def is_overflow(xi0,xi1,...,xim):
return (
xi0 < 0 || xi0 >= shape[0] ||
xi1 < 0 || xi1 >= shape[1] ||
......
xim < 0 || xim >= shape[m] ||
# input is y, output is x
n = len(y.shape)-1
m = len(shape)-1
k = len(overflow_conditions)-1
x = np.zeros(shape, y.dtype)
x[:] = initial_value(op)
for i0 in range(y.shape[0]): # 1-st loop
for i1 in range(y.shape[1]): # 2-nd loop
...... # many loops
for in in range(y.shape[n]) # n+1 -th loop
# indexes[i] is a c++ style integer expression consisting of i0,i1,...,in
xi0,xi1,...,xim = indexes[0],indexes[1],...,indexes[m]
if not is_overflow(xi0,xi1,...,xim):
x[xi0,xi1,...,xim] = op(x[xi0,xi1,...,xim], y[i0,i1,...,in])

# overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in
overflow_conditions[0] ||
overflow_conditions[1] ||
......
overflow_conditions[k]
)
```
# is_overflow is defined as following
def is_overflow(xi0,xi1,...,xim):
return (
xi0 < 0 || xi0 >= shape[0] ||
xi1 < 0 || xi1 >= shape[1] ||
......
xim < 0 || xim >= shape[m] ||

@param[in] y An input jittor Var
# overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in
overflow_conditions[0] ||
overflow_conditions[1] ||
......
overflow_conditions[k]
)

* [in] y: An input jittor Var

@param[in] op a string representing the reduce operation type
* [in] op: a string representing the reduce operation type

@param[in] shape the output shape, an integer array
* [in] shape: the output shape, an integer array

@param[in] indexes array of c++ style integer expression, its length should be
the same as the length of shape, some builtin variables it can use
are: XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem
* [in] indexes: array of c++ style integer expression, its length should be the same as the length of shape, some builtin variables it can use are::

XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem
YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden
i0, i1, ..., in
@e0(...), @e1(...) for extras input index
e0p, e1p , ... for extras input pointer

@param[in] overflow_conditions array of c++ style boolean expression, its length
can vary. the builtin variables it can use are the same as indexes.
* [in] overflow_conditions: array of c++ style boolean expression, its length can vary. the builtin variables it can use are the same as indexes.

@param[in] extras extra var used for index
* [in] extras: extra var used for index

Example

Pooling implemented by reindex operation:
```
def pool(x, size, op):
N,H,W,C = x.shape
h = (H+size-1)//size
w = (W+size-1)//size
return x.reindex_reduce(op, [N,h,w,C], [
"i0", # Nid
f"i1/{size}", # Hid
f"i2/{size}", # Wid
"i3", # Cid
])
```
Pooling implemented by reindex operation::

def pool(x, size, op):
N,H,W,C = x.shape
h = (H+size-1)//size
w = (W+size-1)//size
return x.reindex_reduce(op, [N,h,w,C], [
"i0", # Nid
f"i1/{size}", # Hid
f"i2/{size}", # Wid
"i3", # Cid
])
*/
ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector<string>&& indexes, vector<string>&& overflow_conditions={}, vector<Var*>&& extras={});
@@ -15,17 +15,16 @@ struct WhereOp : Op {
/**
Where Operator generates the index of true conditions.

@param[in] cond condition for index generation
* [in] cond: condition for index generation

@param[in] dtype type of return indexes
* [in] dtype: type of return indexes

@param[out] out return an array of indexes, same length as the number of dims of cond
* [out] out: return an array of indexes, same length as the number of dims of cond

Example
```
jt.where([[0,0,1],[1,0,0]])
# return ( [0,2], [1,0] )
```
Example::

jt.where([[0,0,1],[1,0,0]])
# return ( [0,2], [1,0] )
*/
// @attrs(multiple_outputs)
WhereOp(Var* cond, NanoString dtype=ns_int32);
@@ -23,9 +23,9 @@ typedef struct {

/* Parse the pagemap entry for the given virtual address.
*
* @param[out] entry the parsed entry
* @param[in] pagemap_fd file descriptor to an open /proc/pid/pagemap file
* @param[in] vaddr virtual address to get entry for
* * [out] entry: the parsed entry
* * [in] pagemap_fd: file descriptor to an open /proc/pid/pagemap file
* * [in] vaddr: virtual address to get entry for
* @return 0 for success, 1 for failure
*/
int pagemap_get_entry(PagemapEntry* entry, int pagemap_fd, uintptr_t vaddr)
@@ -55,8 +55,8 @@ int pagemap_get_entry(PagemapEntry* entry, int pagemap_fd, uintptr_t vaddr)

/* Convert the given virtual address to physical using /proc/self/pagemap.
*
* @param[out] paddr physical address
* @param[in] vaddr virtual address to get entry for
* * [out] paddr: physical address
* * [in] vaddr: virtual address to get entry for
* @return 0 for success, 1 for failure
*/
int virt_to_phys_user(uintptr_t* paddr, uintptr_t vaddr)