Move the tensor-tools binary into a separate crate. (#1969)
This commit is contained in:
parent
b190fd8592
commit
3144150b8d
|
@ -9,6 +9,7 @@ members = [
|
||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
"candle-wasm-examples/*",
|
"candle-wasm-examples/*",
|
||||||
"candle-wasm-tests",
|
"candle-wasm-tests",
|
||||||
|
"tensor-tools",
|
||||||
]
|
]
|
||||||
exclude = [
|
exclude = [
|
||||||
"candle-flash-attn",
|
"candle-flash-attn",
|
||||||
|
|
|
@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
|
||||||
`tensor-tools` command line utility via:
|
`tensor-tools` command line utility via:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using custom models
|
## Using custom models
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
[package]
|
||||||
|
name = "tensor-tools"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
description.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
keywords.workspace = true
|
||||||
|
categories.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = { workspace = true }
|
||||||
|
candle = { workspace = true }
|
||||||
|
clap = { workspace = true }
|
||||||
|
rayon = { workspace = true }
|
||||||
|
safetensors = { workspace = true }
|
|
@ -1,5 +1,5 @@
|
||||||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
use candle::quantized::{gguf_file, GgmlDType, QTensor};
|
||||||
use candle_core::{Device, Result};
|
use candle::{Device, Result};
|
||||||
use clap::{Parser, Subcommand, ValueEnum};
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
@ -177,10 +177,10 @@ fn run_print(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if full {
|
if full {
|
||||||
candle_core::display::set_print_options_full();
|
candle::display::set_print_options_full();
|
||||||
}
|
}
|
||||||
if let Some(line_width) = line_width {
|
if let Some(line_width) = line_width {
|
||||||
candle_core::display::set_line_width(line_width)
|
candle::display::set_line_width(line_width)
|
||||||
}
|
}
|
||||||
let format = match format {
|
let format = match format {
|
||||||
Some(format) => format,
|
Some(format) => format,
|
||||||
|
@ -196,7 +196,7 @@ fn run_print(
|
||||||
};
|
};
|
||||||
match format {
|
match format {
|
||||||
Format::Npz => {
|
Format::Npz => {
|
||||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
let tensors = candle::npy::NpzTensors::new(file)?;
|
||||||
for name in names.iter() {
|
for name in names.iter() {
|
||||||
println!("==== {name} ====");
|
println!("==== {name} ====");
|
||||||
match tensors.get(name)? {
|
match tensors.get(name)? {
|
||||||
|
@ -206,8 +206,8 @@ fn run_print(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Safetensors => {
|
Format::Safetensors => {
|
||||||
use candle_core::safetensors::Load;
|
use candle::safetensors::Load;
|
||||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
||||||
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
||||||
for name in names.iter() {
|
for name in names.iter() {
|
||||||
println!("==== {name} ====");
|
println!("==== {name} ====");
|
||||||
|
@ -221,7 +221,7 @@ fn run_print(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Pth => {
|
Format::Pth => {
|
||||||
let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
|
let pth_file = candle::pickle::PthTensors::new(file, None)?;
|
||||||
for name in names.iter() {
|
for name in names.iter() {
|
||||||
println!("==== {name} ====");
|
println!("==== {name} ====");
|
||||||
match pth_file.get(name)? {
|
match pth_file.get(name)? {
|
||||||
|
@ -233,11 +233,11 @@ fn run_print(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Pickle => {
|
Format::Pickle => {
|
||||||
candle_core::bail!("pickle format is not supported for print")
|
candle::bail!("pickle format is not supported for print")
|
||||||
}
|
}
|
||||||
Format::Ggml => {
|
Format::Ggml => {
|
||||||
let mut file = std::fs::File::open(file)?;
|
let mut file = std::fs::File::open(file)?;
|
||||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||||
for name in names.iter() {
|
for name in names.iter() {
|
||||||
println!("==== {name} ====");
|
println!("==== {name} ====");
|
||||||
match content.tensors.get(name) {
|
match content.tensors.get(name) {
|
||||||
|
@ -287,7 +287,7 @@ fn run_ls(
|
||||||
};
|
};
|
||||||
match format {
|
match format {
|
||||||
Format::Npz => {
|
Format::Npz => {
|
||||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
let tensors = candle::npy::NpzTensors::new(file)?;
|
||||||
let mut names = tensors.names();
|
let mut names = tensors.names();
|
||||||
names.sort();
|
names.sort();
|
||||||
for name in names {
|
for name in names {
|
||||||
|
@ -299,12 +299,12 @@ fn run_ls(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Safetensors => {
|
Format::Safetensors => {
|
||||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
||||||
let mut tensors = tensors.tensors();
|
let mut tensors = tensors.tensors();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, view) in tensors.iter() {
|
for (name, view) in tensors.iter() {
|
||||||
let dtype = view.dtype();
|
let dtype = view.dtype();
|
||||||
let dtype = match candle_core::DType::try_from(dtype) {
|
let dtype = match candle::DType::try_from(dtype) {
|
||||||
Ok(dtype) => format!("{dtype:?}"),
|
Ok(dtype) => format!("{dtype:?}"),
|
||||||
Err(_) => format!("{dtype:?}"),
|
Err(_) => format!("{dtype:?}"),
|
||||||
};
|
};
|
||||||
|
@ -313,7 +313,7 @@ fn run_ls(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Pth => {
|
Format::Pth => {
|
||||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
|
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
|
||||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
for tensor_info in tensors.iter() {
|
for tensor_info in tensors.iter() {
|
||||||
println!(
|
println!(
|
||||||
|
@ -330,7 +330,7 @@ fn run_ls(
|
||||||
Format::Pickle => {
|
Format::Pickle => {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let mut reader = std::io::BufReader::new(file);
|
let mut reader = std::io::BufReader::new(file);
|
||||||
let mut stack = candle_core::pickle::Stack::empty();
|
let mut stack = candle::pickle::Stack::empty();
|
||||||
stack.read_loop(&mut reader)?;
|
stack.read_loop(&mut reader)?;
|
||||||
for (i, obj) in stack.stack().iter().enumerate() {
|
for (i, obj) in stack.stack().iter().enumerate() {
|
||||||
println!("{i} {obj:?}");
|
println!("{i} {obj:?}");
|
||||||
|
@ -338,7 +338,7 @@ fn run_ls(
|
||||||
}
|
}
|
||||||
Format::Ggml => {
|
Format::Ggml => {
|
||||||
let mut file = std::fs::File::open(file)?;
|
let mut file = std::fs::File::open(file)?;
|
||||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, qtensor) in tensors.iter() {
|
for (name, qtensor) in tensors.iter() {
|
||||||
|
@ -374,7 +374,7 @@ fn run_quantize_safetensors(
|
||||||
let mut out_file = std::fs::File::create(out_file)?;
|
let mut out_file = std::fs::File::create(out_file)?;
|
||||||
let mut tensors = std::collections::HashMap::new();
|
let mut tensors = std::collections::HashMap::new();
|
||||||
for in_file in in_files.iter() {
|
for in_file in in_files.iter() {
|
||||||
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
|
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
|
||||||
tensors.extend(in_tensors)
|
tensors.extend(in_tensors)
|
||||||
}
|
}
|
||||||
println!("tensors: {}", tensors.len());
|
println!("tensors: {}", tensors.len());
|
||||||
|
@ -416,7 +416,7 @@ fn run_dequantize(
|
||||||
let tensor = tensor.dequantize(device)?;
|
let tensor = tensor.dequantize(device)?;
|
||||||
tensors.insert(tensor_name.to_string(), tensor);
|
tensors.insert(tensor_name.to_string(), tensor);
|
||||||
}
|
}
|
||||||
candle_core::safetensors::save(&tensors, out_file)?;
|
candle::safetensors::save(&tensors, out_file)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -428,11 +428,11 @@ fn run_quantize(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if in_files.is_empty() {
|
if in_files.is_empty() {
|
||||||
candle_core::bail!("no specified input files")
|
candle::bail!("no specified input files")
|
||||||
}
|
}
|
||||||
if let Some(extension) = out_file.extension() {
|
if let Some(extension) = out_file.extension() {
|
||||||
if extension == "safetensors" {
|
if extension == "safetensors" {
|
||||||
candle_core::bail!("the generated file cannot use the safetensors extension")
|
candle::bail!("the generated file cannot use the safetensors extension")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(extension) = in_files[0].extension() {
|
if let Some(extension) = in_files[0].extension() {
|
||||||
|
@ -442,7 +442,7 @@ fn run_quantize(
|
||||||
}
|
}
|
||||||
|
|
||||||
if in_files.len() != 1 {
|
if in_files.len() != 1 {
|
||||||
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
|
candle::bail!("only a single in-file can be used when quantizing gguf files")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open the out file early so as to fail directly on missing directories etc.
|
// Open the out file early so as to fail directly on missing directories etc.
|
Loading…
Reference in New Issue