diff --git a/py/tools/py/src/pth.rs b/py/tools/py/src/pth.rs index 16dde603..6e965caf 100644 --- a/py/tools/py/src/pth.rs +++ b/py/tools/py/src/pth.rs @@ -1,6 +1,6 @@ use std::{ - fs::{self, File}, - io::{BufRead, BufReader, BufWriter, Write}, + fs::{self, DirEntry, File}, + io::{BufRead, BufReader, BufWriter, Read, Write}, path::{Path, PathBuf}, }; @@ -19,27 +19,11 @@ impl PthFile { } } - pub fn copy_to_site_packages(&self, dest: &Path) -> miette::Result<()> { - let dest_pth = dest.join(self.src.file_name().expect(".pth must be a file")); - - if let Some(prefix) = self.prefix.as_deref() { - self.prefix_and_write_pth(dest_pth, prefix) - } else { - fs::copy(self.src.as_path(), dest_pth) - .map(|_| ()) - .into_diagnostic() - .wrap_err("Unable to copy .pth file to site-packages") - } - } - - fn prefix_and_write_pth<P>(&self, dest: P, prefix: &str) -> miette::Result<()> - where - P: AsRef<Path>, - { + pub fn set_up_site_packages(&self, dest: &Path) -> miette::Result<()> { let source_pth = File::open(self.src.as_path()) .into_diagnostic() .wrap_err("Unable to open source .pth file")?; - let dest_pth = File::create(dest) + let dest_pth = File::create(dest.join(self.src.file_name().expect(".pth must be a file"))) .into_diagnostic() .wrap_err("Unable to create destination .pth file")?; @@ -47,14 +31,121 @@ impl PthFile { let mut writer = BufWriter::new(dest_pth); let mut line = String::new(); + let path_prefix = self.prefix.as_ref().map(|pre| Path::new(pre)); + while reader.read_line(&mut line).unwrap() > 0 { - let entry = Path::new(prefix).join(Path::new(line.trim())); + let entry = path_prefix + .map(|pre| pre.join(line.trim())) + .unwrap_or_else(|| PathBuf::from(line.trim())); + line.clear(); - writeln!(writer, "{}", entry.to_string_lossy()) - .into_diagnostic() - .wrap_err("Unable to write new .pth file entry with prefix")?; + + match entry.file_name() { + Some(name) if name == "site-packages" => { + let src_dir = dest + .join(entry) + .canonicalize() + .into_diagnostic() + .wrap_err("Unable to get full source dir path")?; + create_symlinks(&src_dir, &src_dir, &dest)?; + } + _ => { + writeln!(writer, "{}", entry.to_string_lossy()) + .into_diagnostic() + .wrap_err("Unable to write new .pth file entry")?; + } + } } Ok(()) } } + +fn create_symlinks(dir: &Path, root_dir: &Path, dst_dir: &Path) -> miette::Result<()> { + // Create this directory at the destination. + let tgt_dir = dst_dir.join(dir.strip_prefix(root_dir).unwrap()); + std::fs::create_dir_all(&tgt_dir) + .into_diagnostic() + .wrap_err(format!( + "Unable to create parent directory for symlink: {}", + tgt_dir.to_string_lossy() + ))?; + + // Recurse.
+ let read_dir = fs::read_dir(dir).into_diagnostic().wrap_err(format!( + "Unable to read directory {}", + dir.to_string_lossy() + ))?; + + for entry in read_dir { + let entry = entry.into_diagnostic().wrap_err(format!( + "Unable to read directory entry {}", + dir.to_string_lossy() + ))?; + + let path = entry.path(); + + // If this path is a directory, recurse into it, else symlink the file now. + // We must ignore the `__init__.py` file in the root_dir because these are Bazel inserted + // `__init__.py` files in the root site-packages directory. The site-packages directory + // itself is not a regular package and is not supposed to have an `__init__.py` file. + if path.is_dir() { + create_symlinks(&path, root_dir, dst_dir)?; + } else if dir != root_dir || entry.file_name() != "__init__.py" { + create_symlink(&entry, root_dir, dst_dir)?; + } + } + Ok(()) +} + +fn create_symlink(e: &DirEntry, root_dir: &Path, dst_dir: &Path) -> miette::Result<()> { + let tgt = e.path(); + let link = dst_dir.join(tgt.strip_prefix(root_dir).unwrap()); + + // If the link already exists, do not return an error if the link is for an `__init__.py` file + // with the same content as the new destination. Some packages that should ideally be namespace + // packages have copies of `__init__.py` files in their distributions. For example, all the + // Nvidia PyPI packages have the same `nvidia/__init__.py`. So we need to either overwrite the + // previous symlink, or check that the new location also has the same content. + if link.exists() + && link.file_name().is_some_and(|x| x == "__init__.py") + && is_same_file(link.as_path(), tgt.as_path())? 
+ { + return Ok(()); + } + + std::os::unix::fs::symlink(&tgt, &link) + .into_diagnostic() + .wrap_err(format!( + "unable to create symlink: {} -> {}", + tgt.to_string_lossy(), + link.to_string_lossy() + ))?; + + Ok(()) +} + +fn is_same_file(p1: &Path, p2: &Path) -> miette::Result<bool> { + let f1 = File::open(p1) + .into_diagnostic() + .wrap_err(format!("Unable to open file {}", p1.to_string_lossy()))?; + let f2 = File::open(p2) + .into_diagnostic() + .wrap_err(format!("Unable to open file {}", p2.to_string_lossy()))?; + + // Check file size is the same. + if f1.metadata().unwrap().len() != f2.metadata().unwrap().len() { + return Ok(false); + } + + // Compare bytes from the two files in pairs, given that they have the same number of bytes. + let buf1 = BufReader::new(f1); + let buf2 = BufReader::new(f2); + for (b1, b2) in buf1.bytes().zip(buf2.bytes()) { + if b1.unwrap() != b2.unwrap() { + return Ok(false); + } + } + + return Ok(true); +} diff --git a/py/tools/py/src/venv.rs b/py/tools/py/src/venv.rs index 6e6e65fb..14704275 100644 --- a/py/tools/py/src/venv.rs +++ b/py/tools/py/src/venv.rs @@ -51,7 +51,7 @@ pub fn create_venv( .into_diagnostic()?; if let Some(pth) = pth_file { - pth.copy_to_site_packages(&venv_location.join(install_paths.platlib()))? + pth.set_up_site_packages(&venv_location.join(install_paths.platlib()))? } Ok(())