Skip to content

Commit

Permalink
change parameters names
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 28, 2023
1 parent b6029c0 commit a22661b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
20 changes: 11 additions & 9 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,17 @@ fn main() -> tract_core::anyhow::Result<()> {
.long_about("Run the graph")
.arg(Arg::new("dump").long("dump").help("Show output"))
.arg(
Arg::new("save-outputs")
.long("save-outputs")
Arg::new("save-outputs-npz")
.long("save-outputs-npz")
.alias("save-outputs")
.takes_value(true)
.help("Save the outputs into a npz file"),
)
.arg(
Arg::new("save-output-tensors")
.long("save-output-tensors")
Arg::new("save-outputs-nnef")
.long("save-outputs-nnef")
.takes_value(true)
.help("Save the output tensor into a .dat file"),
.help("Save the output tensor into a folder of nnef .dat files"),
)
.arg(Arg::new("steps").long("steps").help("Show all inputs and outputs"))
.arg(
Expand Down Expand Up @@ -401,14 +402,15 @@ fn run_options(command: clap::Command) -> clap::Command {
use clap::*;
command
.arg(
Arg::new("input-from-bundle")
.long("input-from-bundle")
Arg::new("input-from-npz")
.long("input-from-npz")
.alias("input-from-bundle")
.takes_value(true)
.help("Path to an input container (.npz). This sets tensor values."),
)
.arg(
Arg::new("input-from-tensors").long("input-from-tensors").takes_value(true).help(
"Path to a directory containing input tensors (.dat). This sets tensor values.",
Arg::new("input-from-nnef").long("input-from-nnef").takes_value(true).help(
"Path to a directory containing input tensors in NNEF format (.dat files). This sets tensor values.",
),
)
.arg(
Expand Down
4 changes: 2 additions & 2 deletions cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub fn handle(
}
}

if let Some(file_path) = sub_matches.value_of("save-output-tensors") {
if let Some(file_path) = sub_matches.value_of("save-outputs-nnef") {
std::fs::create_dir_all(file_path)
.with_context(|| format!("Creating {file_path} directory"))?;
for (ix, outputs) in outputs.iter().enumerate() {
Expand All @@ -83,7 +83,7 @@ pub fn handle(
}
}

if let Some(file_path) = sub_matches.value_of("save-outputs") {
if let Some(file_path) = sub_matches.value_of("save-outputs-npz") {
let file =
std::fs::File::create(file_path).with_context(|| format!("Creating {file_path}"))?;
let mut npz = ndarray_npy::NpzWriter::new_compressed(file);
Expand Down
4 changes: 2 additions & 2 deletions cli/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ pub fn run_params_from_subcommand(
) -> TractResult<RunParams> {
let mut tv = params.tensors_values.clone();

if let Some(bundle) = sub_matches.values_of("input-from-bundle") {
if let Some(bundle) = sub_matches.values_of("input-from-npz") {
for input in bundle {
for tensor in Parameters::parse_npz(input, true, false)? {
tv.add(tensor);
}
}
}

if let Some(dir) = sub_matches.value_of("input-from-tensors") {
if let Some(dir) = sub_matches.value_of("input-from-nnef") {
for tensor in Parameters::parse_nnef_tensors(dir, true, false)? {
tv.add(tensor);
}
Expand Down

0 comments on commit a22661b

Please sign in to comment.