burn/examples/notebook/plots.ipynb

111 lines
2.4 KiB
Plaintext
Raw Permalink Normal View History

2023-08-17 06:17:12 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# This notebook demonstrates how to display plots and images in Burn."
]
},
{
"cell_type": "code",
2024-01-25 02:32:01 +08:00
"execution_count": null,
2023-08-17 06:17:12 +08:00
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": [
"// Dependency declarations for the notebook. WARNING: It may take a while to compile the first time.\n",
"\n",
"// The syntax is similar to the one used in the Cargo.toml file; just prefix each dependency line with :dep\n",
"// See: https://github.com/evcxr/evcxr/blob/main/COMMON.md\n",
"\n",
"// Burn crates are taken from this repository through relative paths rather than from crates.io.\n",
"// NOTE(review): the relative paths presumably resolve from this notebook's directory (examples/notebook) — confirm if the kernel is started elsewhere.\n",
":dep burn = {path = \"../../crates/burn\"}\n",
":dep burn-ndarray = {path = \"../../crates/burn-ndarray\"}\n",
2023-08-17 06:17:12 +08:00
"\n",
"// The following dependencies are used for plotting\n",
"// NOTE(review): `image` is presumably pinned to 0.23 to match the version evcxr_image 1.1 builds against — confirm before bumping either one.\n",
":dep image = \"0.23\"\n",
":dep evcxr_image = \"1.1\""
]
},
{
"cell_type": "code",
2024-01-25 02:32:01 +08:00
"execution_count": null,
2023-08-17 06:17:12 +08:00
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": [
"// Import packages\n",
"use burn::tensor::Tensor;\n",
2024-01-25 02:32:01 +08:00
"use burn_ndarray::NdArray;\n",
2023-08-17 06:17:12 +08:00
"\n",
"// Import plotting library. NOTE(review): the `ImageDisplay` trait is presumably what lets evcxr render\n",
"// `image::ImageBuffer` values inline when they are a cell's final expression — see the evcxr_image docs.\n",
"use evcxr_image::ImageDisplay;\n",
"\n",
"// Type alias for the backend used by the rest of the notebook: `NdArray` parameterized by `f32`.\n",
"type B = NdArray<f32>;"
2023-08-17 06:17:12 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Image from tensor"
]
},
{
"cell_type": "code",
2024-01-25 02:32:01 +08:00
"execution_count": null,
2023-08-17 06:17:12 +08:00
"metadata": {
"vscode": {
"languageId": "rust"
}
},
2024-01-25 02:32:01 +08:00
"outputs": [],
2023-08-17 06:17:12 +08:00
"source": [
"// Create a random rank-3 tensor of shape [3, 256, 256] sampled from `Distribution::Default`.\n",
"// NOTE(review): the shape is presumably (channels, height, width) to match the 256x256 RGB image below — confirm.\n",
"use burn::tensor::Distribution;\n",
2024-01-25 02:32:01 +08:00
"let tensor: Tensor<B, 3> = Tensor::random([3, 256, 256], Distribution::Default, &Default::default());\n",
2023-08-17 06:17:12 +08:00
"\n",
"// TODO: Use `tensor` to display plots; it is currently unused and the image below is a hard-coded test pattern.\n",
"// Build a 256x256 RGB image: pixels where |x - y| < 3 (a diagonal band) are blue, all others are black.\n",
"// The buffer is left as the cell's final expression (no trailing semicolon) so it becomes the cell's output value.\n",
"image::ImageBuffer::from_fn(256, 256, |x, y| {\n",
"    if (x as i32 - y as i32).abs() < 3 {\n",
"        image::Rgb([0, 0, 255])\n",
"    } else {\n",
"        image::Rgb([0, 0, 0])\n",
"    }\n",
"})\n"
]
2024-01-25 02:32:01 +08:00
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2023-08-17 06:17:12 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Rust",
"language": "rust",
"name": "rust"
},
"language_info": {
"codemirror_mode": "rust",
"file_extension": ".rs",
"mimetype": "text/rust",
"name": "Rust",
"pygment_lexer": "rust",
"version": ""
2024-01-25 02:32:01 +08:00
}
2023-08-17 06:17:12 +08:00
},
"nbformat": 4,
2024-01-25 02:32:01 +08:00
"nbformat_minor": 4
2023-08-17 06:17:12 +08:00
}