mirror of https://github.com/tracel-ai/burn.git
174 lines
4.1 KiB
Plaintext
174 lines
4.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# This notebook demonstrates basic tensor operations in Burn."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "rust"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"// Dependency declarations for the notebook. WARNING: It may take a while to compile the first time.\n",
|
|
"\n",
|
|
"// The syntax mirrors a dependency entry in Cargo.toml; just prefix each line with :dep\n",
|
|
"// See: https://github.com/evcxr/evcxr/blob/main/COMMON.md\n",
|
|
"\n",
|
|
":dep burn = {path = \"../../burn\"}\n",
|
|
":dep burn-ndarray = {path = \"../../burn-ndarray\"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "rust"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"// Import packages\n",
|
|
"use burn::tensor::Tensor;\n",
|
|
"use burn::backend::ndarray::NdArray;\n",
|
|
"\n",
|
|
"// Type alias for the backend\n",
|
|
"type B = NdArray<f32>;"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Tensor Creation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "rust"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Empty tensor: Tensor {\n",
|
|
" data: [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],\n",
|
|
" shape: [1, 2, 3],\n",
|
|
" device: Cpu,\n",
|
|
" backend: \"ndarray\",\n",
|
|
" kind: \"Float\",\n",
|
|
" dtype: \"f32\",\n",
|
|
"}\n",
|
|
"Tensor from slice: Tensor {\n",
|
|
" data: [[1.0, 2.0], [3.0, 4.0]],\n",
|
|
" shape: [2, 2],\n",
|
|
" device: Cpu,\n",
|
|
" backend: \"ndarray\",\n",
|
|
" kind: \"Float\",\n",
|
|
" dtype: \"f32\",\n",
|
|
"}\n",
|
|
"Random tensor: Tensor {\n",
|
|
" data: [0.16685265, 0.7217095, 0.35741878, 0.49403405, 0.27360022],\n",
|
|
" shape: [5],\n",
|
|
" device: Cpu,\n",
|
|
" backend: \"ndarray\",\n",
|
|
" kind: \"Float\",\n",
|
|
" dtype: \"f32\",\n",
|
|
"}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"// Create an empty tensor for a given shape\n",
|
|
"let tensor: Tensor<B, 3> = Tensor::empty([1, 2, 3]);\n",
|
|
"println!(\"Empty tensor: {}\", tensor);\n",
|
|
"\n",
|
|
"// Create a tensor from an array of floats, then reshape it to 2x2\n",
|
|
"let tensor: Tensor<B, 2> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0]).reshape([2, 2]);\n",
|
|
"println!(\"Tensor from slice: {}\", tensor);\n",
|
|
"\n",
|
|
"// Create a random tensor\n",
|
|
"use burn::tensor::Distribution;\n",
|
|
"let tensor: Tensor<B, 1> = Tensor::random([5], Distribution::Default);\n",
|
|
"println!(\"Random tensor: {}\", tensor);\n",
|
|
"\n",
|
|
"// Create tensors filled with a constant value, zeros, or ones\n",
|
|
"let tensor: Tensor<B,2> = Tensor::full([2, 2], 7.0);\n",
|
|
"let tensor: Tensor<B,2> = Tensor::zeros([2, 2]);\n",
|
|
"let tensor: Tensor<B,2> = Tensor::ones([2, 2]);\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Tensor Operations\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "rust"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"x3 = Tensor {\n",
|
|
" data: [[8.0, 8.0], [8.0, 8.0]],\n",
|
|
" shape: [2, 2],\n",
|
|
" device: Cpu,\n",
|
|
" backend: \"ndarray\",\n",
|
|
" kind: \"Float\",\n",
|
|
" dtype: \"f32\",\n",
|
|
"}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"let x1: Tensor<B,2> = Tensor::ones([2, 2]);\n",
|
|
"let x2: Tensor<B,2> = Tensor::full([2, 2], 7.0);\n",
|
|
"\n",
|
|
"let x3 = x1 + x2;\n",
|
|
"\n",
|
|
"println!(\"x3 = {}\", x3);"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Rust",
|
|
"language": "rust",
|
|
"name": "rust"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": "rust",
|
|
"file_extension": ".rs",
|
|
"mimetype": "text/rust",
|
|
"name": "Rust",
|
|
"pygment_lexer": "rust",
|
|
"version": ""
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|