diff --git a/py/tools/py/src/pth.rs b/py/tools/py/src/pth.rs
index 16dde603..6e965caf 100644
--- a/py/tools/py/src/pth.rs
+++ b/py/tools/py/src/pth.rs
@@ -1,6 +1,6 @@
use std::{
- fs::{self, File},
- io::{BufRead, BufReader, BufWriter, Write},
+ fs::{self, DirEntry, File},
+ io::{BufRead, BufReader, BufWriter, Read, Write},
path::{Path, PathBuf},
};
@@ -19,27 +19,11 @@ impl PthFile {
}
}
- pub fn copy_to_site_packages(&self, dest: &Path) -> miette::Result<()> {
- let dest_pth = dest.join(self.src.file_name().expect(".pth must be a file"));
-
- if let Some(prefix) = self.prefix.as_deref() {
- self.prefix_and_write_pth(dest_pth, prefix)
- } else {
- fs::copy(self.src.as_path(), dest_pth)
- .map(|_| ())
- .into_diagnostic()
- .wrap_err("Unable to copy .pth file to site-packages")
- }
- }
-
- fn prefix_and_write_pth
(&self, dest: P, prefix: &str) -> miette::Result<()>
- where
- P: AsRef,
- {
+ pub fn set_up_site_packages(&self, dest: &Path) -> miette::Result<()> {
let source_pth = File::open(self.src.as_path())
.into_diagnostic()
.wrap_err("Unable to open source .pth file")?;
- let dest_pth = File::create(dest)
+ let dest_pth = File::create(dest.join(self.src.file_name().expect(".pth must be a file")))
.into_diagnostic()
.wrap_err("Unable to create destination .pth file")?;
@@ -47,14 +31,121 @@ impl PthFile {
let mut writer = BufWriter::new(dest_pth);
let mut line = String::new();
+ let path_prefix = self.prefix.as_ref().map(|pre| Path::new(pre));
+
while reader.read_line(&mut line).unwrap() > 0 {
- let entry = Path::new(prefix).join(Path::new(line.trim()));
+ let entry = path_prefix
+ .map(|pre| pre.join(line.trim()))
+ .unwrap_or_else(|| PathBuf::from(line.trim()));
+
line.clear();
- writeln!(writer, "{}", entry.to_string_lossy())
- .into_diagnostic()
- .wrap_err("Unable to write new .pth file entry with prefix")?;
+
+ match entry.file_name() {
+ Some(name) if name == "site-packages" => {
+ let src_dir = dest
+ .join(entry)
+ .canonicalize()
+ .into_diagnostic()
+ .wrap_err("Unable to get full source dir path")?;
+ create_symlinks(&src_dir, &src_dir, &dest)?;
+ }
+ _ => {
+ writeln!(writer, "{}", entry.to_string_lossy())
+ .into_diagnostic()
+ .wrap_err("Unable to write new .pth file entry")?;
+ }
+ }
}
Ok(())
}
}
+
+fn create_symlinks(dir: &Path, root_dir: &Path, dst_dir: &Path) -> miette::Result<()> {
+ // Create this directory at the destination.
+ let tgt_dir = dst_dir.join(dir.strip_prefix(root_dir).unwrap());
+ std::fs::create_dir_all(&tgt_dir)
+ .into_diagnostic()
+ .wrap_err(format!(
+ "Unable to create parent directory for symlink: {}",
+ tgt_dir.to_string_lossy()
+ ))?;
+
+ // Recurse.
+ let read_dir = fs::read_dir(dir).into_diagnostic().wrap_err(format!(
+ "Unable to read directory {}",
+ dir.to_string_lossy()
+ ))?;
+
+ for entry in read_dir {
+ let entry = entry.into_diagnostic().wrap_err(format!(
+ "Unable to read directory entry {}",
+ dir.to_string_lossy()
+ ))?;
+
+ let path = entry.path();
+
+ // If this path is a directory, recurse into it, else symlink the file now.
+ // We must ignore the `__init__.py` file in the root_dir because these are Bazel inserted
+ // `__init__.py` files in the root site-packages directory. The site-packages directory
+ // itself is not a regular package and is not supposed to have an `__init__.py` file.
+ if path.is_dir() {
+ create_symlinks(&path, root_dir, dst_dir)?;
+ } else if dir != root_dir || entry.file_name() != "__init__.py" {
+ create_symlink(&entry, root_dir, dst_dir)?;
+ }
+ }
+ Ok(())
+}
+
+fn create_symlink(e: &DirEntry, root_dir: &Path, dst_dir: &Path) -> miette::Result<()> {
+ let tgt = e.path();
+ let link = dst_dir.join(tgt.strip_prefix(root_dir).unwrap());
+
+ // If the link already exists, do not return an error if the link is for an `__init__.py` file
+ // with the same content as the new destination. Some packages that should ideally be namespace
+ // packages have copies of `__init__.py` files in their distributions. For example, all the
+ // Nvidia PyPI packages have the same `nvidia/__init__.py`. So we need to either overwrite the
+ // previous symlink, or check that the new location also has the same content.
+ if link.exists()
+ && link.file_name().is_some_and(|x| x == "__init__.py")
+ && is_same_file(link.as_path(), tgt.as_path())?
+ {
+ return Ok(());
+ }
+
+ std::os::unix::fs::symlink(&tgt, &link)
+ .into_diagnostic()
+ .wrap_err(format!(
+ "unable to create symlink: {} -> {}",
+ tgt.to_string_lossy(),
+ link.to_string_lossy()
+ ))?;
+
+ Ok(())
+}
+
+fn is_same_file(p1: &Path, p2: &Path) -> miette::Result {
+ let f1 = File::open(p1)
+ .into_diagnostic()
+ .wrap_err(format!("Unable to open file {}", p1.to_string_lossy()))?;
+ let f2 = File::open(p2)
+ .into_diagnostic()
+ .wrap_err(format!("Unable to open file {}", p2.to_string_lossy()))?;
+
+ // Check file size is the same.
+ if f1.metadata().unwrap().len() != f2.metadata().unwrap().len() {
+ return Ok(false);
+ }
+
+ // Compare bytes from the two files in pairs, given that they have the same number of bytes.
+ let buf1 = BufReader::new(f1);
+ let buf2 = BufReader::new(f2);
+ for (b1, b2) in buf1.bytes().zip(buf2.bytes()) {
+ if b1.unwrap() != b2.unwrap() {
+ return Ok(false);
+ }
+ }
+
+ return Ok(true);
+}
diff --git a/py/tools/py/src/venv.rs b/py/tools/py/src/venv.rs
index 6e6e65fb..14704275 100644
--- a/py/tools/py/src/venv.rs
+++ b/py/tools/py/src/venv.rs
@@ -51,7 +51,7 @@ pub fn create_venv(
.into_diagnostic()?;
if let Some(pth) = pth_file {
- pth.copy_to_site_packages(&venv_location.join(install_paths.platlib()))?
+ pth.set_up_site_packages(&venv_location.join(install_paths.platlib()))?
}
Ok(())