Move the tensor-tools binary in a separate crate. (#1969)

This commit is contained in:
Laurent Mazare 2024-03-30 15:49:37 +01:00 committed by GitHub
parent b190fd8592
commit 3144150b8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 22 deletions

View File

@ -9,6 +9,7 @@ members = [
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
]
exclude = [
"candle-flash-attn",

View File

@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
## Using custom models

16
tensor-tools/Cargo.toml Normal file
View File

@ -0,0 +1,16 @@
[package]
name = "tensor-tools"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
anyhow = { workspace = true }
candle = { workspace = true }
clap = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }

View File

@ -1,5 +1,5 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{Device, Result};
use candle::quantized::{gguf_file, GgmlDType, QTensor};
use candle::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
@ -177,10 +177,10 @@ fn run_print(
device: &Device,
) -> Result<()> {
if full {
candle_core::display::set_print_options_full();
candle::display::set_print_options_full();
}
if let Some(line_width) = line_width {
candle_core::display::set_line_width(line_width)
candle::display::set_line_width(line_width)
}
let format = match format {
Some(format) => format,
@ -196,7 +196,7 @@ fn run_print(
};
match format {
Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?;
let tensors = candle::npy::NpzTensors::new(file)?;
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name)? {
@ -206,8 +206,8 @@ fn run_print(
}
}
Format::Safetensors => {
use candle_core::safetensors::Load;
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
use candle::safetensors::Load;
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
for name in names.iter() {
println!("==== {name} ====");
@ -221,7 +221,7 @@ fn run_print(
}
}
Format::Pth => {
let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
let pth_file = candle::pickle::PthTensors::new(file, None)?;
for name in names.iter() {
println!("==== {name} ====");
match pth_file.get(name)? {
@ -233,11 +233,11 @@ fn run_print(
}
}
Format::Pickle => {
candle_core::bail!("pickle format is not supported for print")
candle::bail!("pickle format is not supported for print")
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
for name in names.iter() {
println!("==== {name} ====");
match content.tensors.get(name) {
@ -287,7 +287,7 @@ fn run_ls(
};
match format {
Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?;
let tensors = candle::npy::NpzTensors::new(file)?;
let mut names = tensors.names();
names.sort();
for name in names {
@ -299,12 +299,12 @@ fn run_ls(
}
}
Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
let dtype = view.dtype();
let dtype = match candle_core::DType::try_from(dtype) {
let dtype = match candle::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"),
};
@ -313,7 +313,7 @@ fn run_ls(
}
}
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
@ -330,7 +330,7 @@ fn run_ls(
Format::Pickle => {
let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file);
let mut stack = candle_core::pickle::Stack::empty();
let mut stack = candle::pickle::Stack::empty();
stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}");
@ -338,7 +338,7 @@ fn run_ls(
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
@ -374,7 +374,7 @@ fn run_quantize_safetensors(
let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() {
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors)
}
println!("tensors: {}", tensors.len());
@ -416,7 +416,7 @@ fn run_dequantize(
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
candle::safetensors::save(&tensors, out_file)?;
Ok(())
}
@ -428,11 +428,11 @@ fn run_quantize(
device: &Device,
) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
candle::bail!("no specified input files")
}
if let Some(extension) = out_file.extension() {
if extension == "safetensors" {
candle_core::bail!("the generated file cannot use the safetensors extension")
candle::bail!("the generated file cannot use the safetensors extension")
}
}
if let Some(extension) = in_files[0].extension() {
@ -442,7 +442,7 @@ fn run_quantize(
}
if in_files.len() != 1 {
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
candle::bail!("only a single in-file can be used when quantizing gguf files")
}
// Open the out file early so as to fail directly on missing directories etc.