burn/examples/notebook/basic-tensor-op.ipynb

140 lines
3.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
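"# Basic tensor operations in Burn\n",
"\n",
"This notebook demonstrates how to create tensors and apply basic operations to them in Burn, using the NdArray backend."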
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": [
"// Dependency declarations for the notebook.\n",
"// WARNING: the first run may take a while, because the crates have to be compiled.\n",
"\n",
"// The syntax mirrors a dependency entry in a Cargo.toml file; just prefix each line with :dep.\n",
"// See: https://github.com/evcxr/evcxr/blob/main/COMMON.md\n",
"\n",
"// Both crates are loaded from local paths rather than from a registry.\n",
":dep burn = {path = \"../../burn\"}\n",
":dep burn-ndarray = {path = \"../../burn-ndarray\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": [
"// Bring the Burn prelude and the NdArray backend type into scope\n",
"use burn::prelude::*;\n",
"use burn_ndarray::NdArray;\n",
"\n",
"// Backend alias used by the cells below: NdArray parameterized with f32\n",
"type B = NdArray<f32>;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensor creation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": [
"// A device the tensors will be stored on\n",
"let device = <B as Backend>::Device::default();\n",
"\n",
"// Create an empty tensor for a given shape\n",
"let tensor: Tensor<B, 3> = Tensor::empty([1, 2, 3], &device);\n",
"println!(\"Empty tensor: {}\", tensor);\n",
"\n",
"// Create a rank-1 tensor from an array of floats, then reshape it to 2x2.\n",
"// The rank of the intermediate tensor is spelled out explicitly because the\n",
"// annotation on `tensor` only fixes the rank of the reshaped result.\n",
"let tensor: Tensor<B, 2> = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\n",
"println!(\"Tensor from slice: {}\", tensor);\n",
"\n",
"// Create a random tensor\n",
"use burn::tensor::Distribution;\n",
"let tensor: Tensor<B, 1> = Tensor::random([5], Distribution::Default, &device);\n",
"println!(\"Random tensor: {}\", tensor);\n",
"\n",
"// Create tensors filled with a given value, with zeros, or with ones,\n",
"// printing each one so the result of every constructor is visible\n",
"let tensor: Tensor<B, 2> = Tensor::full([2, 2], 7.0, &device);\n",
"println!(\"Full tensor: {}\", tensor);\n",
"let tensor: Tensor<B, 2> = Tensor::zeros([2, 2], &device);\n",
"println!(\"Zeros tensor: {}\", tensor);\n",
"let tensor: Tensor<B, 2> = Tensor::ones([2, 2], &device);\n",
"println!(\"Ones tensor: {}\", tensor);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensor operations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": [
"let x1: Tensor<B,2> = Tensor::ones([2, 2], &device);\n",
"let x2: Tensor<B,2> = Tensor::full([2, 2], 7.0, &device);\n",
"\n",
"let x3 = x1 + x2;\n",
"\n",
"println!(\"x3 = {}\", x3);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Rust",
"language": "rust",
"name": "rust"
},
"language_info": {
"codemirror_mode": "rust",
"file_extension": ".rs",
"mimetype": "text/rust",
"name": "Rust",
"pygment_lexer": "rust",
"version": ""
}
},
"nbformat": 4,
"nbformat_minor": 4
}