Move the tensor-tools binary into a separate crate. (#1969)
This commit is contained in:
parent
b190fd8592
commit
3144150b8d
|
@ -9,6 +9,7 @@ members = [
|
|||
"candle-transformers",
|
||||
"candle-wasm-examples/*",
|
||||
"candle-wasm-tests",
|
||||
"tensor-tools",
|
||||
]
|
||||
exclude = [
|
||||
"candle-flash-attn",
|
||||
|
|
|
@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
|
|||
`tensor-tools` command line utility via:
|
||||
|
||||
```bash
|
||||
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||
```
|
||||
|
||||
## Using custom models
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "tensor-tools"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
candle = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
|
@ -1,5 +1,5 @@
|
|||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||
use candle_core::{Device, Result};
|
||||
use candle::quantized::{gguf_file, GgmlDType, QTensor};
|
||||
use candle::{Device, Result};
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use rayon::prelude::*;
|
||||
|
||||
|
@ -177,10 +177,10 @@ fn run_print(
|
|||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if full {
|
||||
candle_core::display::set_print_options_full();
|
||||
candle::display::set_print_options_full();
|
||||
}
|
||||
if let Some(line_width) = line_width {
|
||||
candle_core::display::set_line_width(line_width)
|
||||
candle::display::set_line_width(line_width)
|
||||
}
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
|
@ -196,7 +196,7 @@ fn run_print(
|
|||
};
|
||||
match format {
|
||||
Format::Npz => {
|
||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||
let tensors = candle::npy::NpzTensors::new(file)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match tensors.get(name)? {
|
||||
|
@ -206,8 +206,8 @@ fn run_print(
|
|||
}
|
||||
}
|
||||
Format::Safetensors => {
|
||||
use candle_core::safetensors::Load;
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||
use candle::safetensors::Load;
|
||||
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
||||
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
|
@ -221,7 +221,7 @@ fn run_print(
|
|||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
|
||||
let pth_file = candle::pickle::PthTensors::new(file, None)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match pth_file.get(name)? {
|
||||
|
@ -233,11 +233,11 @@ fn run_print(
|
|||
}
|
||||
}
|
||||
Format::Pickle => {
|
||||
candle_core::bail!("pickle format is not supported for print")
|
||||
candle::bail!("pickle format is not supported for print")
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
for name in names.iter() {
|
||||
println!("==== {name} ====");
|
||||
match content.tensors.get(name) {
|
||||
|
@ -287,7 +287,7 @@ fn run_ls(
|
|||
};
|
||||
match format {
|
||||
Format::Npz => {
|
||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||
let tensors = candle::npy::NpzTensors::new(file)?;
|
||||
let mut names = tensors.names();
|
||||
names.sort();
|
||||
for name in names {
|
||||
|
@ -299,12 +299,12 @@ fn run_ls(
|
|||
}
|
||||
}
|
||||
Format::Safetensors => {
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
||||
let mut tensors = tensors.tensors();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, view) in tensors.iter() {
|
||||
let dtype = view.dtype();
|
||||
let dtype = match candle_core::DType::try_from(dtype) {
|
||||
let dtype = match candle::DType::try_from(dtype) {
|
||||
Ok(dtype) => format!("{dtype:?}"),
|
||||
Err(_) => format!("{dtype:?}"),
|
||||
};
|
||||
|
@ -313,7 +313,7 @@ fn run_ls(
|
|||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
|
||||
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
|
||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
for tensor_info in tensors.iter() {
|
||||
println!(
|
||||
|
@ -330,7 +330,7 @@ fn run_ls(
|
|||
Format::Pickle => {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut stack = candle_core::pickle::Stack::empty();
|
||||
let mut stack = candle::pickle::Stack::empty();
|
||||
stack.read_loop(&mut reader)?;
|
||||
for (i, obj) in stack.stack().iter().enumerate() {
|
||||
println!("{i} {obj:?}");
|
||||
|
@ -338,7 +338,7 @@ fn run_ls(
|
|||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, qtensor) in tensors.iter() {
|
||||
|
@ -374,7 +374,7 @@ fn run_quantize_safetensors(
|
|||
let mut out_file = std::fs::File::create(out_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for in_file in in_files.iter() {
|
||||
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
|
||||
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
|
||||
tensors.extend(in_tensors)
|
||||
}
|
||||
println!("tensors: {}", tensors.len());
|
||||
|
@ -416,7 +416,7 @@ fn run_dequantize(
|
|||
let tensor = tensor.dequantize(device)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
candle::safetensors::save(&tensors, out_file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -428,11 +428,11 @@ fn run_quantize(
|
|||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if in_files.is_empty() {
|
||||
candle_core::bail!("no specified input files")
|
||||
candle::bail!("no specified input files")
|
||||
}
|
||||
if let Some(extension) = out_file.extension() {
|
||||
if extension == "safetensors" {
|
||||
candle_core::bail!("the generated file cannot use the safetensors extension")
|
||||
candle::bail!("the generated file cannot use the safetensors extension")
|
||||
}
|
||||
}
|
||||
if let Some(extension) = in_files[0].extension() {
|
||||
|
@ -442,7 +442,7 @@ fn run_quantize(
|
|||
}
|
||||
|
||||
if in_files.len() != 1 {
|
||||
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
|
||||
candle::bail!("only a single in-file can be used when quantizing gguf files")
|
||||
}
|
||||
|
||||
// Open the out file early so as to fail directly on missing directories etc.
|
Loading…
Reference in New Issue