diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index ad351171..0bda36d5 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -197,6 +197,11 @@ fn run_print( match format { Format::Npz => { let tensors = candle::npy::NpzTensors::new(file)?; + let names = if names.is_empty() { + tensors.names().into_iter().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name)? { @@ -209,6 +214,11 @@ fn run_print( use candle::safetensors::Load; let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? }; let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect(); + let names = if names.is_empty() { + tensors.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name) { @@ -222,6 +232,15 @@ fn run_print( } Format::Pth => { let pth_file = candle::pickle::PthTensors::new(file, None)?; + let names = if names.is_empty() { + pth_file + .tensor_infos() + .keys() + .map(|v| v.to_string()) + .collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match pth_file.get(name)? { @@ -238,6 +257,11 @@ fn run_print( Format::Ggml => { let mut file = std::fs::File::open(file)?; let content = candle::quantized::ggml_file::Content::read(&mut file, device)?; + let names = if names.is_empty() { + content.tensors.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match content.tensors.get(name) { @@ -252,6 +276,11 @@ fn run_print( Format::Gguf => { let mut file = std::fs::File::open(file)?; let content = gguf_file::Content::read(&mut file)?; + let names = if names.is_empty() { + content.tensor_infos.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match content.tensor(&mut file, name, device) {