Move the tensor-tools binary in a separate crate. (#1969)

This commit is contained in:
Laurent Mazare 2024-03-30 15:49:37 +01:00 committed by GitHub
parent b190fd8592
commit 3144150b8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 22 deletions

View File

@@ -9,6 +9,7 @@ members = [
"candle-transformers", "candle-transformers",
"candle-wasm-examples/*", "candle-wasm-examples/*",
"candle-wasm-tests", "candle-wasm-tests",
"tensor-tools",
] ]
exclude = [ exclude = [
"candle-flash-attn", "candle-flash-attn",

View File

@@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via: `tensor-tools` command line utility via:
```bash ```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf $ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
``` ```
## Using custom models ## Using custom models

16
tensor-tools/Cargo.toml Normal file
View File

@@ -0,0 +1,16 @@
[package]
name = "tensor-tools"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
anyhow = { workspace = true }
candle = { workspace = true }
clap = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }

View File

@@ -1,5 +1,5 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor}; use candle::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{Device, Result}; use candle::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*; use rayon::prelude::*;
@@ -177,10 +177,10 @@ fn run_print(
device: &Device, device: &Device,
) -> Result<()> { ) -> Result<()> {
if full { if full {
candle_core::display::set_print_options_full(); candle::display::set_print_options_full();
} }
if let Some(line_width) = line_width { if let Some(line_width) = line_width {
candle_core::display::set_line_width(line_width) candle::display::set_line_width(line_width)
} }
let format = match format { let format = match format {
Some(format) => format, Some(format) => format,
@@ -196,7 +196,7 @@ fn run_print(
}; };
match format { match format {
Format::Npz => { Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?; let tensors = candle::npy::NpzTensors::new(file)?;
for name in names.iter() { for name in names.iter() {
println!("==== {name} ===="); println!("==== {name} ====");
match tensors.get(name)? { match tensors.get(name)? {
@@ -206,8 +206,8 @@ fn run_print(
} }
} }
Format::Safetensors => { Format::Safetensors => {
use candle_core::safetensors::Load; use candle::safetensors::Load;
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect(); let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
for name in names.iter() { for name in names.iter() {
println!("==== {name} ===="); println!("==== {name} ====");
@@ -221,7 +221,7 @@ fn run_print(
} }
} }
Format::Pth => { Format::Pth => {
let pth_file = candle_core::pickle::PthTensors::new(file, None)?; let pth_file = candle::pickle::PthTensors::new(file, None)?;
for name in names.iter() { for name in names.iter() {
println!("==== {name} ===="); println!("==== {name} ====");
match pth_file.get(name)? { match pth_file.get(name)? {
@@ -233,11 +233,11 @@ fn run_print(
} }
} }
Format::Pickle => { Format::Pickle => {
candle_core::bail!("pickle format is not supported for print") candle::bail!("pickle format is not supported for print")
} }
Format::Ggml => { Format::Ggml => {
let mut file = std::fs::File::open(file)?; let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
for name in names.iter() { for name in names.iter() {
println!("==== {name} ===="); println!("==== {name} ====");
match content.tensors.get(name) { match content.tensors.get(name) {
@@ -287,7 +287,7 @@ fn run_ls(
}; };
match format { match format {
Format::Npz => { Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?; let tensors = candle::npy::NpzTensors::new(file)?;
let mut names = tensors.names(); let mut names = tensors.names();
names.sort(); names.sort();
for name in names { for name in names {
@@ -299,12 +299,12 @@ fn run_ls(
} }
} }
Format::Safetensors => { Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors(); let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() { for (name, view) in tensors.iter() {
let dtype = view.dtype(); let dtype = view.dtype();
let dtype = match candle_core::DType::try_from(dtype) { let dtype = match candle::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"), Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"), Err(_) => format!("{dtype:?}"),
}; };
@@ -313,7 +313,7 @@ fn run_ls(
} }
} }
Format::Pth => { Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?; let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name)); tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() { for tensor_info in tensors.iter() {
println!( println!(
@@ -330,7 +330,7 @@ fn run_ls(
Format::Pickle => { Format::Pickle => {
let file = std::fs::File::open(file)?; let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file); let mut reader = std::io::BufReader::new(file);
let mut stack = candle_core::pickle::Stack::empty(); let mut stack = candle::pickle::Stack::empty();
stack.read_loop(&mut reader)?; stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() { for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}"); println!("{i} {obj:?}");
@@ -338,7 +338,7 @@ fn run_ls(
} }
Format::Ggml => { Format::Ggml => {
let mut file = std::fs::File::open(file)?; let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>(); let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() { for (name, qtensor) in tensors.iter() {
@@ -374,7 +374,7 @@ fn run_quantize_safetensors(
let mut out_file = std::fs::File::create(out_file)?; let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new(); let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() { for in_file in in_files.iter() {
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?; let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors) tensors.extend(in_tensors)
} }
println!("tensors: {}", tensors.len()); println!("tensors: {}", tensors.len());
@@ -416,7 +416,7 @@ fn run_dequantize(
let tensor = tensor.dequantize(device)?; let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor); tensors.insert(tensor_name.to_string(), tensor);
} }
candle_core::safetensors::save(&tensors, out_file)?; candle::safetensors::save(&tensors, out_file)?;
Ok(()) Ok(())
} }
@@ -428,11 +428,11 @@ fn run_quantize(
device: &Device, device: &Device,
) -> Result<()> { ) -> Result<()> {
if in_files.is_empty() { if in_files.is_empty() {
candle_core::bail!("no specified input files") candle::bail!("no specified input files")
} }
if let Some(extension) = out_file.extension() { if let Some(extension) = out_file.extension() {
if extension == "safetensors" { if extension == "safetensors" {
candle_core::bail!("the generated file cannot use the safetensors extension") candle::bail!("the generated file cannot use the safetensors extension")
} }
} }
if let Some(extension) = in_files[0].extension() { if let Some(extension) = in_files[0].extension() {
@@ -442,7 +442,7 @@ fn run_quantize(
} }
if in_files.len() != 1 { if in_files.len() != 1 {
candle_core::bail!("only a single in-file can be used when quantizing gguf files") candle::bail!("only a single in-file can be used when quantizing gguf files")
} }
// Open the out file early so as to fail directly on missing directories etc. // Open the out file early so as to fail directly on missing directories etc.