From 02eb0f61f06a42ff43f3bc944577146d28dfd423 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 09:01:31 +0100 Subject: [PATCH 01/19] Implement stub GraphBuilder --- rust/src/graph/builder/mod.rs | 59 +++++++++++++++++++++++++++++++++++ rust/src/graph/mod.rs | 1 + 2 files changed, 60 insertions(+) create mode 100644 rust/src/graph/builder/mod.rs diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs new file mode 100644 index 00000000..5ddb44f4 --- /dev/null +++ b/rust/src/graph/builder/mod.rs @@ -0,0 +1,59 @@ +use std::path::PathBuf; + +use derive_new::new; + +use crate::graph::Graph; + +#[derive(Debug, Clone, new)] +pub struct PackageSpec { + name: String, + directory: PathBuf, +} + +#[derive(Debug, Clone)] +pub struct GraphBuilder { + package: PackageSpec, // TODO(peter) Support multiple packages + include_external_packages: bool, + exclude_type_checking_imports: bool, +} + +impl GraphBuilder { + pub fn new(package: PackageSpec) -> Self { + GraphBuilder { + package, + include_external_packages: false, + exclude_type_checking_imports: false, + } + } + + pub fn include_external_packages(mut self, yes: bool) -> Self { + self.include_external_packages = yes; + self + } + + pub fn exclude_type_checking_imports(mut self, yes: bool) -> Self { + self.exclude_type_checking_imports = yes; + self + } + + pub fn build(&self) -> Graph { + todo!() + // 1. Find all python files in the package. + // Use the `ignore` crate. + // + // 2. For each python file in the package, parse the imports. + // Use the existing `parse_imports_from_code` function. + // + // 3. For each python file in the package, resolve the imports. + // You can reuse the existing logic in `scan_for_imports_no_py_single_module`, + // but not directly. Copy the minimum, necessary code over to a new module + // here in `rust/src/graph/builder/`. + // + // 4. Assemble the graph. Copy logic from the python implementation `_assemble_graph`. + // + // 5. Create a python usecase in `src/grimp/application/usecases.py` called `build_graph_rust`. + // + // * Do not do any parallelization yet. + // * Do not do any caching yet. + } +} diff --git a/rust/src/graph/mod.rs b/rust/src/graph/mod.rs index b01e82b7..a787a650 100644 --- a/rust/src/graph/mod.rs +++ b/rust/src/graph/mod.rs @@ -19,6 +19,7 @@ use crate::graph::higher_order_queries::Level; use crate::graph::higher_order_queries::PackageDependency as PyPackageDependency; use crate::module_expressions::ModuleExpression; +pub mod builder; pub mod direct_import_queries; pub mod graph_manipulation; pub mod hierarchy_queries; From f6c351e4a84455a526a0dae0451a827407850376 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 10:27:18 +0100 Subject: [PATCH 02/19] Add build_graph_rust usecase --- .importlinter | 1 + rust/Cargo.lock | 73 +++++++ rust/Cargo.toml | 1 + rust/src/graph/builder/mod.rs | 198 ++++++++++++++++-- rust/src/graph/builder/utils.rs | 117 +++++++++++ rust/src/graph/mod.rs | 4 + rust/src/graph_building.rs | 54 +++++ rust/src/lib.rs | 4 + src/grimp/__init__.py | 7 +- src/grimp/application/graph.py | 10 +- src/grimp/application/usecases.py | 43 +++- src/grimp/main.py | 6 +- .../test_build_graph_on_real_packages.py | 5 + 13 files changed, 492 insertions(+), 31 deletions(-) create mode 100644 rust/src/graph/builder/utils.rs create mode 100644 rust/src/graph_building.rs diff --git a/.importlinter b/.importlinter index c58f0d92..097c591d 100644 --- a/.importlinter +++ b/.importlinter @@ -21,3 +21,4 @@ ignore_imports = grimp.application.graph -> grimp grimp.adaptors.filesystem -> grimp grimp.application.scanning -> grimp + grimp.application.usecases -> grimp diff --git a/rust/Cargo.lock b/rust/Cargo.lock index a00f0b1b..89dd88da 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "derive-new", "encoding_rs", "getset", + "ignore", "indexmap 2.11.0", "itertools 0.14.0", "parameterized", @@ -190,6 +191,19 @@ dependencies = [ "syn", ] +[[package]] +name = "globset" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52dfc19153a48bde0cbd630453615c8151bce3a5adfac7a0aebfbf0a1e1f57e3" +dependencies = [ + "aho-corasick", + "bstr", + "log", + "regex-automata", + "regex-syntax", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -211,6 +225,22 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "ignore" +version = "0.4.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3d782a365a015e0f5c04902246139249abf769125006fbe7649e2ee88169b4a" +dependencies = [ + "crossbeam-deque", + "globset", + "log", + "memchr", + "regex-automata", + "same-file", + "walkdir", + "winapi-util", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -638,6 +668,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "serde" version = "1.0.219" @@ -839,12 +878,46 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + [[package]] name = "zerocopy" version = "0.8.26" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 4eb4e139..f81a2665 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -29,6 +29,7 @@ serde_json = "1.0.137" serde_yaml = "0.9" unindent = "0.2.4" encoding_rs = "0.8.35" +ignore = "0.4" [dependencies.pyo3] version = "0.26.0" diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 5ddb44f4..dd549e37 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -1,8 +1,18 @@ +use std::collections::{HashMap, HashSet}; +use std::fs; use std::path::PathBuf; use derive_new::new; +use ignore::WalkBuilder; use crate::graph::Graph; +use crate::import_parsing::{ImportedObject, parse_imports_from_code}; + +mod utils; +use utils::{ + ResolvedImport, is_internal, is_package, path_to_module_name, resolve_external_module, + resolve_internal_module, resolve_relative_import, +}; #[derive(Debug, Clone, new)] pub struct PackageSpec { @@ -15,6 +25,7 @@ pub struct GraphBuilder { package: PackageSpec, // TODO(peter) Support multiple packages include_external_packages: bool, exclude_type_checking_imports: bool, + // cache_dir: Option // TODO(peter) } impl GraphBuilder { @@ -37,23 +48,174 @@ impl GraphBuilder { } pub fn build(&self) -> Graph { - todo!() - // 1. Find all python files in the package. - // Use the `ignore` crate. - // - // 2. For each python file in the package, parse the imports. - // Use the existing `parse_imports_from_code` function. - // - // 3. For each python file in the package, resolve the imports. - // You can reuse the existing logic in `scan_for_imports_no_py_single_module`, - // but not directly. Copy the minimum, necessary code over to a new module - // here in `rust/src/graph/builder/`. - // - // 4. Assemble the graph. Copy logic from the python implementation `_assemble_graph`. - // - // 5. Create a python usecase in `src/grimp/application/usecases.py` called `build_graph_rust`. - // - // * Do not do any parallelization yet. - // * Do not do any caching yet. + let modules = discover_python_modules(&self.package); + + let parsed_modules = parse_imports(&modules); + + let imports_by_module = resolve_imports( + &parsed_modules, + self.include_external_packages, + self.exclude_type_checking_imports, + ); + + assemble_graph(&imports_by_module, &self.package.name) } } + +#[derive(Debug, Clone)] +struct FoundModule { + name: String, + path: PathBuf, + is_package: bool, +} + +#[derive(Debug)] +struct ParsedModule { + module: FoundModule, + imported_objects: Vec, +} + +fn discover_python_modules(package: &PackageSpec) -> Vec { + let mut modules = Vec::new(); + + let walker = WalkBuilder::new(&package.directory) + .standard_filters(false) // Don't use gitignore or other filters + .filter_entry(|entry| { + // Allow Python files + if entry.file_type().is_some_and(|ft| ft.is_file()) { + return entry.path().extension().and_then(|s| s.to_str()) == Some("py"); + } + + // For directories, only descend if they contain __init__.py + if entry.file_type().is_some_and(|ft| ft.is_dir()) { + let init_path = entry.path().join("__init__.py"); + return init_path.exists(); + } + + false + }) + .build(); + + for entry in walker.flatten() { + let path = entry.path(); + if let Some(module_name) = path_to_module_name(path, package) { + let is_package = is_package(path); + modules.push(FoundModule { + name: module_name, + path: path.to_owned(), + is_package, + }); + } + } + + modules +} + +fn parse_imports(modules: &[FoundModule]) -> Vec { + let mut parsed_modules = Vec::new(); + + for module in modules { + // Read the file + if let Ok(code) = fs::read_to_string(&module.path) { + // Parse imports + if let Ok(imported_objects) = + parse_imports_from_code(&code, module.path.to_str().unwrap_or("")) + { + parsed_modules.push(ParsedModule { + module: module.clone(), + imported_objects, + }); + } + } + } + + parsed_modules +} + +fn resolve_imports( + parsed_modules: &[ParsedModule], + include_external_packages: bool, + exclude_type_checking_imports: bool, +) -> HashMap> { + let all_modules: HashSet = parsed_modules + .iter() + .map(|module| module.module.name.clone()) + .collect(); + + let mut imports_by_module = HashMap::new(); + for parsed_module in parsed_modules { + let mut resolved_imports = HashSet::new(); + + // Resolve each imported object + for imported_object in &parsed_module.imported_objects { + // Skip type checking imports if requested + if exclude_type_checking_imports && imported_object.typechecking_only { + continue; + } + + // Resolve relative imports to absolute + let absolute_import_name = resolve_relative_import( + &parsed_module.module.name, + parsed_module.module.is_package, + &imported_object.name, + ); + + // Try to resolve as internal module first + if let Some(internal_module) = + resolve_internal_module(&absolute_import_name, &all_modules) + { + resolved_imports.insert(ResolvedImport { + importer: parsed_module.module.name.to_string(), + imported: internal_module, + line_number: imported_object.line_number, + line_contents: imported_object.line_contents.clone(), + }); + } else if include_external_packages { + // It's an external module and we're including them + let external_module = resolve_external_module(&absolute_import_name); + resolved_imports.insert(ResolvedImport { + importer: parsed_module.module.name.to_string(), + imported: external_module, + line_number: imported_object.line_number, + line_contents: imported_object.line_contents.clone(), + }); + } + } + + imports_by_module.insert(parsed_module.module.name.clone(), resolved_imports); + } + + imports_by_module +} + +fn assemble_graph( + imports_by_module: &HashMap>, + package_name: &str, +) -> Graph { + let mut graph = Graph::default(); + + // Add all modules and their imports + for (module_name, imports) in imports_by_module { + // Add the module itself and get its token + let importer_token = graph.get_or_add_module(module_name).token(); + + for import in imports { + // Add the imported module + let imported_token = if is_internal(&import.imported, package_name) { + graph.get_or_add_module(&import.imported).token() + } else { + graph.get_or_add_squashed_module(&import.imported).token() + }; + + // Add the import with details + graph.add_detailed_import( + importer_token, + imported_token, + import.line_number as u32, + &import.line_contents, + ); + } + } + + graph +} diff --git a/rust/src/graph/builder/utils.rs b/rust/src/graph/builder/utils.rs new file mode 100644 index 00000000..cd8c67f9 --- /dev/null +++ b/rust/src/graph/builder/utils.rs @@ -0,0 +1,117 @@ +use std::collections::HashSet; +use std::path::Path; + +use crate::graph::builder::PackageSpec; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub struct ResolvedImport { + pub importer: String, + pub imported: String, + pub line_number: usize, + pub line_contents: String, +} + +/// Check if a module filename represents a package (i.e., __init__.py) +pub fn is_package(module_path: &Path) -> bool { + module_path + .file_name() + .and_then(|name| name.to_str()) + .map(|name| name == "__init__.py") + .unwrap_or(false) +} + +/// Check if module is internal +pub fn is_internal(module_name: &str, package: &str) -> bool { + if module_name == package || module_name.starts_with(&format!("{}.", package)) { + return true; + } + false +} + +/// Convert module path to module name +pub fn path_to_module_name(module_path: &Path, package: &PackageSpec) -> Option { + let relative_path = module_path.strip_prefix(&package.directory).ok()?; + + let mut components: Vec = vec![package.name.clone()]; + for component in relative_path.iter() { + let component_str = component.to_str()?; + if component_str == "__init__.py" { + // This is a package, don't add __init__ + break; + } else if component_str.ends_with(".py") { + // Strip .py extension + components.push(component_str.strip_suffix(".py")?.to_string()); + } else { + // Directory component + components.push(component_str.to_string()); + } + } + + Some(components.join(".")) +} + +/// Convert a relative import to an absolute import name +pub fn resolve_relative_import( + module_name: &str, + is_package: bool, + imported_object_name: &str, +) -> String { + let leading_dots_count = imported_object_name + .chars() + .take_while(|&c| c == '.') + .count(); + + if leading_dots_count == 0 { + return imported_object_name.to_string(); + } + + let imported_object_name_base = if is_package { + if leading_dots_count == 1 { + module_name.to_string() + } else { + let parts: Vec<&str> = module_name.split('.').collect(); + parts[0..parts.len() - leading_dots_count + 1].join(".") + } + } else { + let parts: Vec<&str> = module_name.split('.').collect(); + parts[0..parts.len() - leading_dots_count].join(".") + }; + + format!( + "{}.{}", + imported_object_name_base, + &imported_object_name[leading_dots_count..] + ) +} + +/// Resolve an imported object name to an internal module +pub fn resolve_internal_module( + imported_object_name: &str, + all_modules: &HashSet, +) -> Option { + let candidate_module = imported_object_name.to_string(); + + if all_modules.contains(&candidate_module) { + return Some(candidate_module); + } + + // Check if parent module exists + if let Some((parent, _)) = imported_object_name.rsplit_once('.') + && all_modules.contains(parent) + { + return Some(parent.to_string()); + } + + None +} + +/// Get external module name +pub fn resolve_external_module(module_name: &str) -> String { + // For simplicity, just return the root module for external imports + // This matches the basic behavior from _distill_external_module + module_name + .split('.') + .next() + .unwrap_or(module_name) + .to_string() +} diff --git a/rust/src/graph/mod.rs b/rust/src/graph/mod.rs index a787a650..4dda54e0 100644 --- a/rust/src/graph/mod.rs +++ b/rust/src/graph/mod.rs @@ -82,6 +82,10 @@ pub struct GraphWrapper { } impl GraphWrapper { + pub fn from_graph(graph: Graph) -> Self { + GraphWrapper { _graph: graph } + } + fn get_visible_module_by_name(&self, name: &str) -> Result<&Module, ModuleNotPresent> { self._graph .get_module_by_name(name) diff --git a/rust/src/graph_building.rs b/rust/src/graph_building.rs new file mode 100644 index 00000000..8767338e --- /dev/null +++ b/rust/src/graph_building.rs @@ -0,0 +1,54 @@ +use pyo3::prelude::*; +use std::path::PathBuf; + +use crate::graph::GraphWrapper; +use crate::graph::builder::{GraphBuilder, PackageSpec}; + +#[pyclass(name = "PackageSpec")] +#[derive(Clone)] +pub struct PyPackageSpec { + inner: PackageSpec, +} + +#[pymethods] +impl PyPackageSpec { + #[new] + fn new(name: String, directory: String) -> Self { + PyPackageSpec { + inner: PackageSpec::new(name, PathBuf::from(directory)), + } + } +} + +#[pyclass(name = "GraphBuilder")] +pub struct PyGraphBuilder { + inner: GraphBuilder, +} + +#[pymethods] +impl PyGraphBuilder { + #[new] + fn new(package: PyPackageSpec) -> Self { + PyGraphBuilder { + inner: GraphBuilder::new(package.inner), + } + } + + fn include_external_packages(mut self_: PyRefMut<'_, Self>, yes: bool) -> PyRefMut<'_, Self> { + self_.inner = self_.inner.clone().include_external_packages(yes); + self_ + } + + fn exclude_type_checking_imports( + mut self_: PyRefMut<'_, Self>, + yes: bool, + ) -> PyRefMut<'_, Self> { + self_.inner = self_.inner.clone().exclude_type_checking_imports(yes); + self_ + } + + fn build(&self) -> GraphWrapper { + let graph = self.inner.build(); + GraphWrapper::from_graph(graph) + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index e63b501b..3b211383 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -3,6 +3,7 @@ pub mod errors; pub mod exceptions; mod filesystem; pub mod graph; +mod graph_building; pub mod import_parsing; mod import_scanning; pub mod module_expressions; @@ -31,4 +32,7 @@ mod _rustgrimp { use crate::exceptions::{ CorruptCache, InvalidModuleExpression, ModuleNotPresent, NoSuchContainer, ParseError, }; + + #[pymodule_export] + use crate::graph_building::{PyGraphBuilder, PyPackageSpec}; } diff --git a/src/grimp/__init__.py b/src/grimp/__init__.py index 26fa6298..90c6ac03 100644 --- a/src/grimp/__init__.py +++ b/src/grimp/__init__.py @@ -1,9 +1,9 @@ __version__ = "3.13" -from .application.graph import DetailedImport, ImportGraph, Import +from .application.graph import DetailedImport, Import, ImportGraph from .domain.analysis import PackageDependency, Route -from .domain.valueobjects import DirectImport, Module, Layer -from .main import build_graph +from .domain.valueobjects import DirectImport, Layer, Module +from .main import build_graph, build_graph_rust __all__ = [ "Module", @@ -14,5 +14,6 @@ "PackageDependency", "Route", "build_graph", + "build_graph_rust", "Layer", ] diff --git a/src/grimp/application/graph.py b/src/grimp/application/graph.py index 0847511c..b6350c87 100644 --- a/src/grimp/application/graph.py +++ b/src/grimp/application/graph.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import TypedDict + from collections.abc import Sequence +from typing import TypedDict + +from grimp import _rustgrimp as rust # type: ignore[attr-defined] from grimp.domain.analysis import PackageDependency, Route from grimp.domain.valueobjects import Layer -from grimp import _rustgrimp as rust # type: ignore[attr-defined] from grimp.exceptions import ( + InvalidImportExpression, + InvalidModuleExpression, ModuleNotPresent, NoSuchContainer, - InvalidModuleExpression, - InvalidImportExpression, ) diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index e7503d29..8ca4f6b4 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -2,17 +2,17 @@ Use cases handle application logic. """ +from collections.abc import Iterable, Sequence from typing import cast -from collections.abc import Sequence, Iterable -from .scanning import scan_imports +from ..application.graph import ImportGraph from ..application.ports import caching from ..application.ports.filesystem import AbstractFileSystem, BasicFileSystem -from ..application.graph import ImportGraph from ..application.ports.modulefinder import AbstractModuleFinder, FoundPackage, ModuleFile from ..application.ports.packagefinder import AbstractPackageFinder from ..domain.valueobjects import DirectImport, Module from .config import settings +from .scanning import scan_imports class NotSupplied: @@ -68,6 +68,43 @@ def build_graph( return graph +def build_graph_rust( + package_name, + *additional_package_names, + include_external_packages: bool = False, + exclude_type_checking_imports: bool = False, + cache_dir: str | type[NotSupplied] | None = NotSupplied, +) -> ImportGraph: + """ + Build and return an import graph for the supplied package(s) using the Rust implementation. + """ + from grimp import _rustgrimp as rust # type: ignore[attr-defined] + + file_system: AbstractFileSystem = settings.FILE_SYSTEM + package_finder: AbstractPackageFinder = settings.PACKAGE_FINDER + + # Determine the package directory + package_directory = package_finder.determine_package_directory( + package_name=package_name, file_system=file_system + ) + + # Create the graph_builder + package_spec = rust.PackageSpec(package_name, package_directory) + graph_builder = rust.GraphBuilder(package_spec) + if include_external_packages: + graph_builder = graph_builder.include_external_packages(True) + if exclude_type_checking_imports: + graph_builder = graph_builder.exclude_type_checking_imports(True) + + # Build the graph + rust_graph = graph_builder.build() + + # Wrap the rust graph in our ImportGraph wrapper + graph = ImportGraph() + graph._rustgraph = rust_graph + return graph + + def _find_packages( file_system: AbstractFileSystem, package_names: Sequence[object] ) -> set[FoundPackage]: diff --git a/src/grimp/main.py b/src/grimp/main.py index 45a297a0..88c5fc7f 100644 --- a/src/grimp/main.py +++ b/src/grimp/main.py @@ -1,13 +1,13 @@ -__all__ = ["build_graph"] +__all__ = ["build_graph", "build_graph_rust"] from .adaptors.caching import Cache from .adaptors.filesystem import FileSystem -from .application.graph import ImportGraph from .adaptors.modulefinder import ModuleFinder from .adaptors.packagefinder import ImportLibPackageFinder from .adaptors.timing import SystemClockTimer from .application.config import settings -from .application.usecases import build_graph +from .application.graph import ImportGraph +from .application.usecases import build_graph, build_graph_rust settings.configure( MODULE_FINDER=ModuleFinder(), diff --git a/tests/functional/test_build_graph_on_real_packages.py b/tests/functional/test_build_graph_on_real_packages.py index b2da12af..0f92f77d 100644 --- a/tests/functional/test_build_graph_on_real_packages.py +++ b/tests/functional/test_build_graph_on_real_packages.py @@ -12,6 +12,11 @@ def test_build_graph_on_real_package(package_name, snapshot): imports = graph.find_matching_direct_imports(import_expression="** -> **") assert imports == snapshot + graph_from_rust = grimp.build_graph_rust(package_name, cache_dir=None) + imports_from_rust = graph_from_rust.find_matching_direct_imports(import_expression="** -> **") + sort_key = lambda i: (i["importer"], i["imported"]) # noqa:E731 + assert sorted(imports_from_rust, key=sort_key) == sorted(imports, key=sort_key) + @pytest.mark.parametrize( "package_name", From 56545662cac428b095bc7cf4b9613b0558b7faa2 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 10:58:59 +0100 Subject: [PATCH 03/19] TEMP Add benchmark script --- benchmark_build_graph.py | 81 ++++++++++++++++++++++++++++++++++++++++ justfile | 10 ++++- 2 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 benchmark_build_graph.py diff --git a/benchmark_build_graph.py b/benchmark_build_graph.py new file mode 100644 index 00000000..15c73550 --- /dev/null +++ b/benchmark_build_graph.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +"""Benchmark build_graph vs build_graph_rust.""" + +import argparse +import os +import sys +import time + +import grimp + + +def benchmark_build_graph(package_name: str, working_dir: str | None = None) -> None: + """Benchmark both graph building implementations.""" + if working_dir: + os.chdir(working_dir) + print(f"Changed directory to: {working_dir}") + # Add working directory to Python path + if working_dir not in sys.path: + sys.path.insert(0, working_dir) + print(f"Added to PYTHONPATH: {working_dir}\n") + + print(f"Benchmarking graph building for package: {package_name}") + print("=" * 60) + + # Benchmark Python version + print("\nPython version (build_graph):") + start = time.perf_counter() + graph_py = grimp.build_graph(package_name, cache_dir=None) + elapsed_py = time.perf_counter() - start + + modules_py = len(graph_py.modules) + imports_py = len(graph_py.find_matching_direct_imports(import_expression="** -> **")) + + print(f" Time: {elapsed_py:.3f}s") + print(f" Modules: {modules_py}") + print(f" Imports: {imports_py}") + + # Benchmark Rust version + print("\nRust version (build_graph_rust):") + start = time.perf_counter() + graph_rust = grimp.build_graph_rust(package_name, cache_dir=None) + elapsed_rust = time.perf_counter() - start + + modules_rust = len(graph_rust.modules) + imports_rust = len(graph_rust.find_matching_direct_imports(import_expression="** -> **")) + + print(f" Time: {elapsed_rust:.3f}s") + print(f" Modules: {modules_rust}") + print(f" Imports: {imports_rust}") + + # Compare + print("\n" + "=" * 60) + print("Comparison:") + speedup = elapsed_py / elapsed_rust if elapsed_rust > 0 else float("inf") + print(f" Speedup: {speedup:.2f}x") + print(f" Python: {elapsed_py:.3f}s") + print(f" Rust: {elapsed_rust:.3f}s") + + # Verify correctness + if modules_py != modules_rust: + print(f"\n⚠️ Warning: Module count mismatch ({modules_py} vs {modules_rust})") + if imports_py != imports_rust: + print(f"⚠️ Warning: Import count mismatch ({imports_py} vs {imports_rust})") + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark build_graph vs build_graph_rust") + parser.add_argument("package", help="Package name to analyze") + parser.add_argument("-d", "--directory", help="Working directory to change to before running") + + args = parser.parse_args() + + try: + benchmark_build_graph(args.package, args.directory) + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/justfile b/justfile index b0db10a3..b2622a90 100644 --- a/justfile +++ b/justfile @@ -12,6 +12,11 @@ install-precommit: compile: @uv run maturin develop +# Compiles the Rust code in release mode. +[group('testing')] +compile-release: + @uv run maturin develop --release + # Compiles Rust, then runs Rust and Python tests. [group('testing')] compile-and-test: @@ -169,4 +174,7 @@ full-check: @just lint @just build-docs @just test-all - @echo '👍 {{GREEN}} Linting, docs and tests all good.{{NORMAL}}' \ No newline at end of file + @echo '👍 {{GREEN}} Linting, docs and tests all good.{{NORMAL}}' + +benchmark-build-graph-rust-vs-python package_name package_dir: compile-release + uv run benchmark_build_graph.py {{package_name}} -d {{package_dir}} From ab992a8532c5ab8590e07081e60a1990a252962c Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 11:23:30 +0100 Subject: [PATCH 04/19] Implement parallelism in graph building --- rust/Cargo.lock | 32 ++++++++++++++ rust/Cargo.toml | 1 + rust/src/graph/builder/mod.rs | 82 +++++++++++++++++++++++++---------- 3 files changed, 93 insertions(+), 22 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 89dd88da..d1597804 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -8,6 +8,7 @@ version = "0.1.0" dependencies = [ "bimap", "const_format", + "crossbeam", "derive-new", "encoding_rs", "getset", @@ -96,6 +97,28 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -115,6 +138,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index f81a2665..6f00b3b0 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -30,6 +30,7 @@ serde_yaml = "0.9" unindent = "0.2.4" encoding_rs = "0.8.35" ignore = "0.4" +crossbeam = "0.8" [dependencies.pyo3] version = "0.26.0" diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index dd549e37..cbb02c41 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -1,7 +1,9 @@ use std::collections::{HashMap, HashSet}; use std::fs; use std::path::PathBuf; +use std::thread; +use crossbeam::channel; use derive_new::new; use ignore::WalkBuilder; @@ -48,17 +50,64 @@ impl GraphBuilder { } pub fn build(&self) -> Graph { - let modules = discover_python_modules(&self.package); + // Create channels for communication + let (module_discovery_sender, module_discovery_receiver) = channel::bounded(10000); + let (import_parser_sender, import_parser_receiver) = channel::bounded(10000); + + let mut thread_handles = Vec::new(); + + // Thread 1: Discover modules + let package = self.package.clone(); + let handle = thread::spawn(move || { + let modules = discover_python_modules(&package); + // Send modules to parser threads + for module in modules { + module_discovery_sender.send(module).unwrap(); + } + drop(module_discovery_sender); // Close channel to signal completion + }); + thread_handles.push(handle); + + // Thread pool: Parse imports + let num_workers = thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4); + for _ in 0..num_workers { + let receiver = module_discovery_receiver.clone(); + let sender = import_parser_sender.clone(); + let handle = thread::spawn(move || { + while let Ok(module) = receiver.recv() { + if let Some(parsed) = parse_module_imports(&module) { + sender.send(parsed).unwrap(); + } + } + }); + thread_handles.push(handle); + } + drop(module_discovery_receiver); // Close original receiver + drop(import_parser_sender); // Close original sender - let parsed_modules = parse_imports(&modules); + // Collect parsed modules + let mut parsed_modules = Vec::new(); + while let Ok(parsed) = import_parser_receiver.recv() { + parsed_modules.push(parsed); + } + // Wait for all threads to complete + for handle in thread_handles { + handle.join().unwrap(); + } + + // Resolve imports and assemble graph (sequential) let imports_by_module = resolve_imports( &parsed_modules, self.include_external_packages, self.exclude_type_checking_imports, ); - assemble_graph(&imports_by_module, &self.package.name) + let graph = assemble_graph(&imports_by_module, &self.package.name); + + graph } } @@ -111,25 +160,14 @@ fn discover_python_modules(package: &PackageSpec) -> Vec { modules } -fn parse_imports(modules: &[FoundModule]) -> Vec { - let mut parsed_modules = Vec::new(); - - for module in modules { - // Read the file - if let Ok(code) = fs::read_to_string(&module.path) { - // Parse imports - if let Ok(imported_objects) = - parse_imports_from_code(&code, module.path.to_str().unwrap_or("")) - { - parsed_modules.push(ParsedModule { - module: module.clone(), - imported_objects, - }); - } - } - } - - parsed_modules +fn parse_module_imports(module: &FoundModule) -> Option { + let code = fs::read_to_string(&module.path).ok()?; + let imported_objects = + parse_imports_from_code(&code, module.path.to_str().unwrap_or("")).ok()?; + Some(ParsedModule { + module: module.clone(), + imported_objects, + }) } fn resolve_imports( From f2032621c8a8e708d529e92955aab53cd5956120 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 11:56:49 +0100 Subject: [PATCH 05/19] Implement caching --- benchmark_build_graph.py | 174 +++++++++++++++++++++++------- rust/Cargo.lock | 33 ++++++ rust/Cargo.toml | 1 + rust/src/graph/builder/cache.rs | 67 ++++++++++++ rust/src/graph/builder/mod.rs | 61 +++++++++-- rust/src/graph_building.rs | 5 + rust/src/import_parsing.rs | 4 +- src/grimp/application/usecases.py | 7 ++ 8 files changed, 309 insertions(+), 43 deletions(-) create mode 100644 rust/src/graph/builder/cache.rs diff --git a/benchmark_build_graph.py b/benchmark_build_graph.py index 15c73550..aa52e036 100644 --- a/benchmark_build_graph.py +++ b/benchmark_build_graph.py @@ -3,12 +3,98 @@ import argparse import os +import shutil import sys import time +from dataclasses import dataclass +from typing import Callable import grimp +@dataclass +class BenchmarkResult: + """Result of a single benchmark run.""" + + name: str + elapsed: float + modules: int + imports: int + + +def run_benchmark( + name: str, + build_func: Callable, + package_name: str, + cache_dir: str | None, +) -> BenchmarkResult: + """Run a single benchmark and return the result.""" + print(f"\n{name}:") + start = time.perf_counter() + graph = build_func(package_name, cache_dir=cache_dir) + elapsed = time.perf_counter() - start + + modules = len(graph.modules) + imports = len(graph.find_matching_direct_imports(import_expression="** -> **")) + + print(f" Time: {elapsed:.3f}s") + print(f" Modules: {modules}") + print(f" Imports: {imports}") + + return BenchmarkResult(name, elapsed, modules, imports) + + +def cleanup_cache_dir(cache_dir: str) -> None: + """Remove cache directory.""" + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + + +def print_comparison( + py_results: list[BenchmarkResult], rust_results: list[BenchmarkResult] +) -> None: + """Print comparison of benchmark results.""" + py_no_cache, py_cold, py_warm = py_results + rust_no_cache, rust_cold, rust_warm = rust_results + + print("\n" + "=" * 60) + print("Comparison:") + print(f" Python (no cache): {py_no_cache.elapsed:.3f}s") + print( + f" Python (cold cache): {py_cold.elapsed:.3f}s " + f"({py_no_cache.elapsed / py_cold.elapsed:.2f}x speedup)" + ) + print( + f" Python (warm cache): {py_warm.elapsed:.3f}s " + f"({py_no_cache.elapsed / py_warm.elapsed:.2f}x speedup)" + ) + print( + f" Rust (no cache): {rust_no_cache.elapsed:.3f}s " + f"({py_no_cache.elapsed / rust_no_cache.elapsed:.2f}x vs Python no cache)" + ) + print( + f" Rust (cold cache): {rust_cold.elapsed:.3f}s " + f"({py_no_cache.elapsed / rust_cold.elapsed:.2f}x vs Python no cache)" + ) + print( + f" Rust (warm cache): {rust_warm.elapsed:.3f}s " + f"({py_no_cache.elapsed / rust_warm.elapsed:.2f}x vs Python no cache)" + ) + print(f"\n Python cache speedup: {py_no_cache.elapsed / py_warm.elapsed:.2f}x") + print(f" Rust cache speedup: {rust_no_cache.elapsed / rust_warm.elapsed:.2f}x") + + # Verify correctness + if py_no_cache.modules != rust_no_cache.modules: + print( + f"\n⚠️ Warning: Module count mismatch " + f"({py_no_cache.modules} vs {rust_no_cache.modules})" + ) + if py_no_cache.imports != rust_no_cache.imports: + print( + f"⚠️ Warning: Import count mismatch ({py_no_cache.imports} vs {rust_no_cache.imports})" + ) + + def benchmark_build_graph(package_name: str, working_dir: str | None = None) -> None: """Benchmark both graph building implementations.""" if working_dir: @@ -22,45 +108,61 @@ def benchmark_build_graph(package_name: str, working_dir: str | None = None) -> print(f"Benchmarking graph building for package: {package_name}") print("=" * 60) - # Benchmark Python version - print("\nPython version (build_graph):") - start = time.perf_counter() - graph_py = grimp.build_graph(package_name, cache_dir=None) - elapsed_py = time.perf_counter() - start - - modules_py = len(graph_py.modules) - imports_py = len(graph_py.find_matching_direct_imports(import_expression="** -> **")) + cache_dir = ".grimp_cache_benchmark" - print(f" Time: {elapsed_py:.3f}s") - print(f" Modules: {modules_py}") - print(f" Imports: {imports_py}") + # Benchmark Python version + py_no_cache = run_benchmark( + "Python version without cache (build_graph)", + grimp.build_graph, + package_name, + None, + ) + + cleanup_cache_dir(cache_dir) + py_cold = run_benchmark( + "Python version with cache - first run (cold cache)", + grimp.build_graph, + package_name, + cache_dir, + ) + + py_warm = run_benchmark( + "Python version with cache - second run (warm cache)", + grimp.build_graph, + package_name, + cache_dir, + ) # Benchmark Rust version - print("\nRust version (build_graph_rust):") - start = time.perf_counter() - graph_rust = grimp.build_graph_rust(package_name, cache_dir=None) - elapsed_rust = time.perf_counter() - start - - modules_rust = len(graph_rust.modules) - imports_rust = len(graph_rust.find_matching_direct_imports(import_expression="** -> **")) - - print(f" Time: {elapsed_rust:.3f}s") - print(f" Modules: {modules_rust}") - print(f" Imports: {imports_rust}") - - # Compare - print("\n" + "=" * 60) - print("Comparison:") - speedup = elapsed_py / elapsed_rust if elapsed_rust > 0 else float("inf") - print(f" Speedup: {speedup:.2f}x") - print(f" Python: {elapsed_py:.3f}s") - print(f" Rust: {elapsed_rust:.3f}s") - - # Verify correctness - if modules_py != modules_rust: - print(f"\n⚠️ Warning: Module count mismatch ({modules_py} vs {modules_rust})") - if imports_py != imports_rust: - print(f"⚠️ Warning: Import count mismatch ({imports_py} vs {imports_rust})") + rust_no_cache = run_benchmark( + "Rust version without cache (build_graph_rust)", + grimp.build_graph_rust, + package_name, + None, + ) + + cleanup_cache_dir(cache_dir) + rust_cold = run_benchmark( + "Rust version with cache - first run (cold cache)", + grimp.build_graph_rust, + package_name, + cache_dir, + ) + + rust_warm = run_benchmark( + "Rust version with cache - second run (warm cache)", + grimp.build_graph_rust, + package_name, + cache_dir, + ) + + cleanup_cache_dir(cache_dir) + + # Print comparison + print_comparison( + [py_no_cache, py_cold, py_warm], + [rust_no_cache, rust_cold, rust_warm], + ) def main(): diff --git a/rust/Cargo.lock b/rust/Cargo.lock index d1597804..104ffb89 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -7,6 +7,7 @@ name = "_rustgrimp" version = "0.1.0" dependencies = [ "bimap", + "bincode", "const_format", "crossbeam", "derive-new", @@ -54,6 +55,26 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bitflags" version = "2.9.4" @@ -904,12 +925,24 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "version_check" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "walkdir" version = "2.5.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6f00b3b0..96cb5a57 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -31,6 +31,7 @@ unindent = "0.2.4" encoding_rs = "0.8.35" ignore = "0.4" crossbeam = "0.8" +bincode = "2.0.0-rc.3" [dependencies.pyo3] version = "0.26.0" diff --git a/rust/src/graph/builder/cache.rs b/rust/src/graph/builder/cache.rs new file mode 100644 index 00000000..c4af53f9 --- /dev/null +++ b/rust/src/graph/builder/cache.rs @@ -0,0 +1,67 @@ +use std::collections::HashMap; +use std::fs; +use std::io::{Read as _, Write as _}; +use std::path::{Path, PathBuf}; + +use bincode::{Decode, Encode}; + +use crate::import_parsing::ImportedObject; + +#[derive(Debug, Clone, Encode, Decode)] +pub struct CachedImports { + mtime_secs: i64, + imported_objects: Vec, +} + +impl CachedImports { + pub fn new(mtime_secs: i64, imported_objects: Vec) -> Self { + Self { + mtime_secs, + imported_objects, + } + } + + pub fn mtime_secs(&self) -> i64 { + self.mtime_secs + } + + pub fn imported_objects(&self) -> &[ImportedObject] { + &self.imported_objects + } +} + +pub type ImportCache = HashMap; + +pub fn load_cache(cache_dir: &Path, package_name: &str) -> ImportCache { + let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); + + if let Ok(mut file) = fs::File::open(&cache_file) { + let mut buffer = Vec::new(); + if file.read_to_end(&mut buffer).is_ok() + && let Ok(cache) = + bincode::decode_from_slice::(&buffer, bincode::config::standard()) + { + return cache.0; + } + } + + HashMap::new() +} + +pub fn save_cache(cache: &ImportCache, cache_dir: &Path, package_name: &str) { + if let Err(e) = fs::create_dir_all(cache_dir) { + eprintln!("Failed to create cache directory: {}", e); + return; + } + + let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); + + match bincode::encode_to_vec(cache, bincode::config::standard()) { + Ok(encoded) => { + if let Ok(mut file) = fs::File::create(&cache_file) { + let _ = file.write_all(&encoded); + } + } + Err(e) => eprintln!("Failed to encode cache: {}", e), + } +} diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index cbb02c41..065a2fce 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -10,6 +10,9 @@ use ignore::WalkBuilder; use crate::graph::Graph; use crate::import_parsing::{ImportedObject, parse_imports_from_code}; +mod cache; +use cache::{CachedImports, ImportCache, load_cache, save_cache}; + mod utils; use utils::{ ResolvedImport, is_internal, is_package, path_to_module_name, resolve_external_module, @@ -27,7 +30,7 @@ pub struct GraphBuilder { package: PackageSpec, // TODO(peter) Support multiple packages include_external_packages: bool, exclude_type_checking_imports: bool, - // cache_dir: Option // TODO(peter) + cache_dir: Option, } impl GraphBuilder { @@ -36,6 +39,7 @@ impl GraphBuilder { package, include_external_packages: false, exclude_type_checking_imports: false, + cache_dir: None, } } @@ -49,7 +53,19 @@ impl GraphBuilder { self } + pub fn cache_dir(mut self, cache_dir: Option) -> Self { + self.cache_dir = cache_dir; + self + } + pub fn build(&self) -> Graph { + // Load cache if available + let mut cache = self + .cache_dir + .as_ref() + .map(|dir| load_cache(dir, &self.package.name)) + .unwrap_or_default(); + // Create channels for communication let (module_discovery_sender, module_discovery_receiver) = channel::bounded(10000); let (import_parser_sender, import_parser_receiver) = channel::bounded(10000); @@ -75,9 +91,10 @@ impl GraphBuilder { for _ in 0..num_workers { let receiver = module_discovery_receiver.clone(); let sender = import_parser_sender.clone(); + let cache = cache.clone(); let handle = thread::spawn(move || { while let Ok(module) = receiver.recv() { - if let Some(parsed) = parse_module_imports(&module) { + if let Some(parsed) = parse_module_imports(&module, &cache) { sender.send(parsed).unwrap(); } } @@ -98,6 +115,17 @@ impl GraphBuilder { handle.join().unwrap(); } + // Update and save cache if cache_dir is set + if let Some(cache_dir) = &self.cache_dir { + for parsed in &parsed_modules { + cache.insert( + parsed.module.path.clone(), + CachedImports::new(parsed.module.mtime_secs, parsed.imported_objects.clone()), + ); + } + save_cache(&cache, cache_dir, &self.package.name); + } + // Resolve imports and assemble graph (sequential) let imports_by_module = resolve_imports( &parsed_modules, @@ -105,9 +133,7 @@ impl GraphBuilder { self.exclude_type_checking_imports, ); - let graph = assemble_graph(&imports_by_module, &self.package.name); - - graph + assemble_graph(&imports_by_module, &self.package.name) } } @@ -116,6 +142,7 @@ struct FoundModule { name: String, path: PathBuf, is_package: bool, + mtime_secs: i64, } #[derive(Debug)] @@ -149,10 +176,20 @@ fn discover_python_modules(package: &PackageSpec) -> Vec { let path = entry.path(); if let Some(module_name) = path_to_module_name(path, package) { let is_package = is_package(path); + + // Get mtime + let mtime_secs = fs::metadata(path) + .and_then(|m| m.modified()) + .ok() + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| d.as_secs() as i64) + .unwrap_or(0); + modules.push(FoundModule { name: module_name, path: path.to_owned(), is_package, + mtime_secs, }); } } @@ -160,7 +197,19 @@ fn discover_python_modules(package: &PackageSpec) -> Vec { modules } -fn parse_module_imports(module: &FoundModule) -> Option { +fn parse_module_imports(module: &FoundModule, cache: &ImportCache) -> Option { + // Check if we have a cached version with matching mtime + if let Some(cached) = cache.get(&module.path) + && module.mtime_secs == cached.mtime_secs() + { + // Cache hit - use cached imports + return Some(ParsedModule { + module: module.clone(), + imported_objects: cached.imported_objects().to_vec(), + }); + } + + // Cache miss or file modified - parse the file let code = fs::read_to_string(&module.path).ok()?; let imported_objects = parse_imports_from_code(&code, module.path.to_str().unwrap_or("")).ok()?; diff --git a/rust/src/graph_building.rs b/rust/src/graph_building.rs index 8767338e..813019a0 100644 --- a/rust/src/graph_building.rs +++ b/rust/src/graph_building.rs @@ -47,6 +47,11 @@ impl PyGraphBuilder { self_ } + fn cache_dir(mut self_: PyRefMut<'_, Self>, cache_dir: Option) -> PyRefMut<'_, Self> { + self_.inner = self_.inner.clone().cache_dir(cache_dir.map(PathBuf::from)); + self_ + } + fn build(&self) -> GraphWrapper { let graph = self.inner.build(); GraphWrapper::from_graph(graph) diff --git a/rust/src/import_parsing.rs b/rust/src/import_parsing.rs index a43f7422..0fa64b08 100644 --- a/rust/src/import_parsing.rs +++ b/rust/src/import_parsing.rs @@ -1,10 +1,12 @@ +use bincode::{Decode, Encode}; + use crate::errors::{GrimpError, GrimpResult}; use ruff_python_ast::statement_visitor::{StatementVisitor, walk_body, walk_stmt}; use ruff_python_ast::{Expr, Stmt}; use ruff_python_parser::parse_module; use ruff_source_file::{LineIndex, SourceCode}; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Encode, Decode)] pub struct ImportedObject { pub name: String, pub line_number: usize, diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index 8ca4f6b4..cb5ceca2 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -96,6 +96,13 @@ def build_graph_rust( if exclude_type_checking_imports: graph_builder = graph_builder.exclude_type_checking_imports(True) + # Handle cache_dir + if cache_dir is not None: + if cache_dir is NotSupplied: + graph_builder = graph_builder.cache_dir(".grimp_cache") + else: + graph_builder = graph_builder.cache_dir(cache_dir) + # Build the graph rust_graph = graph_builder.build() From a13685170998f3c0d772d60615d49fe880fc811e Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 12:27:25 +0100 Subject: [PATCH 06/19] Also use parallelism to discover Python modules --- rust/src/graph/builder/mod.rs | 83 +++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 065a2fce..d1f746a8 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::fs; use std::path::PathBuf; @@ -75,18 +76,13 @@ impl GraphBuilder { // Thread 1: Discover modules let package = self.package.clone(); let handle = thread::spawn(move || { - let modules = discover_python_modules(&package); - // Send modules to parser threads - for module in modules { - module_discovery_sender.send(module).unwrap(); - } - drop(module_discovery_sender); // Close channel to signal completion + discover_python_modules(&package, module_discovery_sender); }); thread_handles.push(handle); // Thread pool: Parse imports let num_workers = thread::available_parallelism() - .map(|n| n.get()) + .map(|n| max(n.get() / 2, 1)) .unwrap_or(4); for _ in 0..num_workers { let receiver = module_discovery_receiver.clone(); @@ -151,11 +147,16 @@ struct ParsedModule { imported_objects: Vec, } -fn discover_python_modules(package: &PackageSpec) -> Vec { - let mut modules = Vec::new(); +fn discover_python_modules(package: &PackageSpec, sender: channel::Sender) { + let num_threads = thread::available_parallelism() + .map(|n| max(n.get() / 2, 1)) + .unwrap_or(4); + + let package_clone = package.clone(); - let walker = WalkBuilder::new(&package.directory) + WalkBuilder::new(&package.directory) .standard_filters(false) // Don't use gitignore or other filters + .threads(num_threads) .filter_entry(|entry| { // Allow Python files if entry.file_type().is_some_and(|ft| ft.is_file()) { @@ -170,31 +171,45 @@ fn discover_python_modules(package: &PackageSpec) -> Vec { false }) - .build(); - - for entry in walker.flatten() { - let path = entry.path(); - if let Some(module_name) = path_to_module_name(path, package) { - let is_package = is_package(path); - - // Get mtime - let mtime_secs = fs::metadata(path) - .and_then(|m| m.modified()) - .ok() - .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) - .map(|d| d.as_secs() as i64) - .unwrap_or(0); - - modules.push(FoundModule { - name: module_name, - path: path.to_owned(), - is_package, - mtime_secs, - }); - } - } + .build_parallel() + .run(|| { + let sender = sender.clone(); + let package = package_clone.clone(); + + Box::new(move |entry| { + use ignore::WalkState; + + let entry = match entry { + Ok(e) => e, + Err(_) => return WalkState::Continue, + }; + + let path = entry.path(); + if let Some(module_name) = path_to_module_name(path, &package) { + let is_package = is_package(path); + + // Get mtime + let mtime_secs = fs::metadata(path) + .and_then(|m| m.modified()) + .ok() + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| d.as_secs() as i64) + .unwrap_or(0); + + let found_module = FoundModule { + name: module_name, + path: path.to_owned(), + is_package, + mtime_secs, + }; + + // Send module as soon as we discover it + let _ = sender.send(found_module); + } - modules + WalkState::Continue + }) + }); } fn parse_module_imports(module: &FoundModule, cache: &ImportCache) -> Option { From a539f95dd5bb5d9d46628f16748017fc6dc252eb Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 12:33:02 +0100 Subject: [PATCH 07/19] TEMP Update benchmarks to use build_graph_rust --- tests/benchmarking/test_benchmarking.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/benchmarking/test_benchmarking.py b/tests/benchmarking/test_benchmarking.py index aece8dad..25c1f4b4 100644 --- a/tests/benchmarking/test_benchmarking.py +++ b/tests/benchmarking/test_benchmarking.py @@ -1,14 +1,15 @@ -import uuid -import random -import pytest -import json import importlib +import json +import random +import uuid +from copy import deepcopy from pathlib import Path -from grimp.application.graph import ImportGraph -from grimp import PackageDependency, Route +import pytest + import grimp -from copy import deepcopy +from grimp import PackageDependency, Route +from grimp.application.graph import ImportGraph @pytest.fixture(scope="module") @@ -303,7 +304,7 @@ def test_build_django_uncached(benchmark): In this benchmark, the cache is turned off. """ - benchmark(grimp.build_graph, "django", cache_dir=None) + benchmark(grimp.build_graph_rust, "django", cache_dir=None) def test_build_django_from_cache_no_misses(benchmark): @@ -313,9 +314,9 @@ def test_build_django_from_cache_no_misses(benchmark): This benchmark fully utilizes the cache. """ # Populate the cache first, before beginning the benchmark. - grimp.build_graph("django") + grimp.build_graph_rust("django") - benchmark(grimp.build_graph, "django") + benchmark(grimp.build_graph_rust, "django") @pytest.mark.parametrize( @@ -364,7 +365,7 @@ def test_build_django_from_cache_a_few_misses(benchmark, number_of_misses: int): # turn off multiple runs, which could potentially be misleading when running locally. # Populate the cache first, before beginning the benchmark. - grimp.build_graph("django") + grimp.build_graph_rust("django") # Add some modules which won't be in the cache. # (Use some real python, which will take time to parse.) django_path = Path(importlib.util.find_spec("django").origin).parent # type: ignore @@ -380,7 +381,7 @@ def test_build_django_from_cache_a_few_misses(benchmark, number_of_misses: int): hash_buster = f"\n# Hash busting comment: {uuid.uuid4()}" new_module.write_text(module_contents + hash_buster) - benchmark.pedantic(grimp.build_graph, ["django"], rounds=1, iterations=1) + benchmark.pedantic(grimp.build_graph_rust, ["django"], rounds=1, iterations=1) # Delete the modules we just created. for module in extra_modules: From 16b59b22d3b0d937c35ce647f555867cd7b1f15b Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 14:40:53 +0100 Subject: [PATCH 08/19] Make build_graph a simple function More pythonic than builder pattern. --- rust/src/graph/builder/mod.rs | 149 ++++++++++++------------------ rust/src/graph_building.rs | 54 ++++------- rust/src/lib.rs | 2 +- src/grimp/application/usecases.py | 24 ++--- 4 files changed, 89 insertions(+), 140 deletions(-) diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index d1f746a8..8b7da59c 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -26,111 +26,80 @@ pub struct PackageSpec { directory: PathBuf, } -#[derive(Debug, Clone)] -pub struct GraphBuilder { - package: PackageSpec, // TODO(peter) Support multiple packages +pub fn build_graph( + package: &PackageSpec, // TODO(peter) Support multiple packages include_external_packages: bool, exclude_type_checking_imports: bool, - cache_dir: Option, -} - -impl GraphBuilder { - pub fn new(package: PackageSpec) -> Self { - GraphBuilder { - package, - include_external_packages: false, - exclude_type_checking_imports: false, - cache_dir: None, - } - } - - pub fn include_external_packages(mut self, yes: bool) -> Self { - self.include_external_packages = yes; - self - } - - pub fn exclude_type_checking_imports(mut self, yes: bool) -> Self { - self.exclude_type_checking_imports = yes; - self - } - - pub fn cache_dir(mut self, cache_dir: Option) -> Self { - self.cache_dir = cache_dir; - self - } + cache_dir: Option<&PathBuf>, +) -> Graph { + // Load cache if available + let mut cache = cache_dir + .map(|dir| load_cache(dir, &package.name)) + .unwrap_or_default(); - pub fn build(&self) -> Graph { - // Load cache if available - let mut cache = self - .cache_dir - .as_ref() - .map(|dir| load_cache(dir, &self.package.name)) - .unwrap_or_default(); + // Create channels for communication + let (module_discovery_sender, module_discovery_receiver) = channel::bounded(10000); + let (import_parser_sender, import_parser_receiver) = channel::bounded(10000); - // Create channels for communication - let (module_discovery_sender, module_discovery_receiver) = channel::bounded(10000); - let (import_parser_sender, import_parser_receiver) = channel::bounded(10000); + let mut thread_handles = Vec::new(); - let mut thread_handles = Vec::new(); + // Thread 1: Discover modules + let package_clone = package.clone(); + let handle = thread::spawn(move || { + discover_python_modules(&package_clone, module_discovery_sender); + }); + thread_handles.push(handle); - // Thread 1: Discover modules - let package = self.package.clone(); + // Thread pool: Parse imports + let num_workers = thread::available_parallelism() + .map(|n| max(n.get() / 2, 1)) + .unwrap_or(4); + for _ in 0..num_workers { + let receiver = module_discovery_receiver.clone(); + let sender = import_parser_sender.clone(); + let cache = cache.clone(); let handle = thread::spawn(move || { - discover_python_modules(&package, module_discovery_sender); + while let Ok(module) = receiver.recv() { + if let Some(parsed) = parse_module_imports(&module, &cache) { + sender.send(parsed).unwrap(); + } + } }); thread_handles.push(handle); + } + drop(module_discovery_receiver); // Close original receiver + drop(import_parser_sender); // Close original sender - // Thread pool: Parse imports - let num_workers = thread::available_parallelism() - .map(|n| max(n.get() / 2, 1)) - .unwrap_or(4); - for _ in 0..num_workers { - let receiver = module_discovery_receiver.clone(); - let sender = import_parser_sender.clone(); - let cache = cache.clone(); - let handle = thread::spawn(move || { - while let Ok(module) = receiver.recv() { - if let Some(parsed) = parse_module_imports(&module, &cache) { - sender.send(parsed).unwrap(); - } - } - }); - thread_handles.push(handle); - } - drop(module_discovery_receiver); // Close original receiver - drop(import_parser_sender); // Close original sender - - // Collect parsed modules - let mut parsed_modules = Vec::new(); - while let Ok(parsed) = import_parser_receiver.recv() { - parsed_modules.push(parsed); - } + // Collect parsed modules + let mut parsed_modules = Vec::new(); + while let Ok(parsed) = import_parser_receiver.recv() { + parsed_modules.push(parsed); + } - // Wait for all threads to complete - for handle in thread_handles { - handle.join().unwrap(); - } + // Wait for all threads to complete + for handle in thread_handles { + handle.join().unwrap(); + } - // Update and save cache if cache_dir is set - if let Some(cache_dir) = &self.cache_dir { - for parsed in &parsed_modules { - cache.insert( - parsed.module.path.clone(), - CachedImports::new(parsed.module.mtime_secs, parsed.imported_objects.clone()), - ); - } - save_cache(&cache, cache_dir, &self.package.name); + // Update and save cache if cache_dir is set + if let Some(cache_dir) = cache_dir { + for parsed in &parsed_modules { + cache.insert( + parsed.module.path.clone(), + CachedImports::new(parsed.module.mtime_secs, parsed.imported_objects.clone()), + ); } + save_cache(&cache, cache_dir, &package.name); + } - // Resolve imports and assemble graph (sequential) - let imports_by_module = resolve_imports( - &parsed_modules, - self.include_external_packages, - self.exclude_type_checking_imports, - ); + // Resolve imports and assemble graph (sequential) + let imports_by_module = resolve_imports( + &parsed_modules, + include_external_packages, + exclude_type_checking_imports, + ); - assemble_graph(&imports_by_module, &self.package.name) - } + assemble_graph(&imports_by_module, &package.name) } #[derive(Debug, Clone)] diff --git a/rust/src/graph_building.rs b/rust/src/graph_building.rs index 813019a0..b3072662 100644 --- a/rust/src/graph_building.rs +++ b/rust/src/graph_building.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use std::path::PathBuf; use crate::graph::GraphWrapper; -use crate::graph::builder::{GraphBuilder, PackageSpec}; +use crate::graph::builder::{PackageSpec, build_graph}; #[pyclass(name = "PackageSpec")] #[derive(Clone)] @@ -20,40 +20,20 @@ impl PyPackageSpec { } } -#[pyclass(name = "GraphBuilder")] -pub struct PyGraphBuilder { - inner: GraphBuilder, -} - -#[pymethods] -impl PyGraphBuilder { - #[new] - fn new(package: PyPackageSpec) -> Self { - PyGraphBuilder { - inner: GraphBuilder::new(package.inner), - } - } - - fn include_external_packages(mut self_: PyRefMut<'_, Self>, yes: bool) -> PyRefMut<'_, Self> { - self_.inner = self_.inner.clone().include_external_packages(yes); - self_ - } - - fn exclude_type_checking_imports( - mut self_: PyRefMut<'_, Self>, - yes: bool, - ) -> PyRefMut<'_, Self> { - self_.inner = self_.inner.clone().exclude_type_checking_imports(yes); - self_ - } - - fn cache_dir(mut self_: PyRefMut<'_, Self>, cache_dir: Option) -> PyRefMut<'_, Self> { - self_.inner = self_.inner.clone().cache_dir(cache_dir.map(PathBuf::from)); - self_ - } - - fn build(&self) -> GraphWrapper { - let graph = self.inner.build(); - GraphWrapper::from_graph(graph) - } +#[pyfunction] +#[pyo3(signature = (package, include_external_packages=false, exclude_type_checking_imports=false, cache_dir=None))] +pub fn build_graph_rust( + package: &PyPackageSpec, + include_external_packages: bool, + exclude_type_checking_imports: bool, + cache_dir: Option, +) -> GraphWrapper { + let cache_path = cache_dir.map(PathBuf::from); + let graph = build_graph( + &package.inner, + include_external_packages, + exclude_type_checking_imports, + cache_path.as_ref(), + ); + GraphWrapper::from_graph(graph) } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 3b211383..35aa1736 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -34,5 +34,5 @@ mod _rustgrimp { }; #[pymodule_export] - use crate::graph_building::{PyGraphBuilder, PyPackageSpec}; + use crate::graph_building::{PyPackageSpec, build_graph_rust}; } diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index cb5ceca2..4eb05440 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -88,23 +88,23 @@ def build_graph_rust( package_name=package_name, file_system=file_system ) - # Create the graph_builder + # Create package spec and build the graph package_spec = rust.PackageSpec(package_name, package_directory) - graph_builder = rust.GraphBuilder(package_spec) - if include_external_packages: - graph_builder = graph_builder.include_external_packages(True) - if exclude_type_checking_imports: - graph_builder = graph_builder.exclude_type_checking_imports(True) # Handle cache_dir - if cache_dir is not None: - if cache_dir is NotSupplied: - graph_builder = graph_builder.cache_dir(".grimp_cache") - else: - graph_builder = graph_builder.cache_dir(cache_dir) + cache_dir_arg: str | None = None + if cache_dir is NotSupplied: + cache_dir_arg = ".grimp_cache" + elif isinstance(cache_dir, str): + cache_dir_arg = cache_dir # Build the graph - rust_graph = graph_builder.build() + rust_graph = rust.build_graph_rust( + package_spec, + include_external_packages=include_external_packages, + exclude_type_checking_imports=exclude_type_checking_imports, + cache_dir=cache_dir_arg, + ) # Wrap the rust graph in our ImportGraph wrapper graph = ImportGraph() From c80031c5c5f04b28504504277194fbe26c7e03e0 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 15:31:27 +0100 Subject: [PATCH 09/19] Implement error handling in build_graph --- rust/src/errors.rs | 20 ++++++++- rust/src/graph/builder/cache.rs | 36 ++++++++++------ rust/src/graph/builder/mod.rs | 76 ++++++++++++++++++++++++--------- rust/src/graph_building.rs | 6 +-- 4 files changed, 101 insertions(+), 37 deletions(-) diff --git a/rust/src/errors.rs b/rust/src/errors.rs index 2a894edb..9bd310a4 100644 --- a/rust/src/errors.rs +++ b/rust/src/errors.rs @@ -1,6 +1,6 @@ use crate::exceptions; use pyo3::PyErr; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyFileNotFoundError, PyIOError, PyValueError}; use ruff_python_parser::ParseError as RuffParseError; use thiserror::Error; @@ -36,6 +36,18 @@ pub enum GrimpError { #[error("Could not use corrupt cache file {0}.")] CorruptCache(String), + + #[error("Failed to read file {path}: {error}")] + FileReadError { path: String, error: String }, + + #[error("Failed to get file metadata for {path}: {error}")] + FileMetadataError { path: String, error: String }, + + #[error("Failed to write cache file {path}: {error}")] + CacheWriteError { path: String, error: String }, + + #[error("Package directory does not exist: {0}")] + PackageDirectoryNotFound(String), } pub type GrimpResult = Result; @@ -55,6 +67,12 @@ impl From for PyErr { line_number, text, .. } => PyErr::new::((line_number, text)), GrimpError::CorruptCache(_) => exceptions::CorruptCache::new_err(value.to_string()), + GrimpError::FileReadError { .. } => PyIOError::new_err(value.to_string()), + GrimpError::FileMetadataError { .. } => PyIOError::new_err(value.to_string()), + GrimpError::CacheWriteError { .. } => PyIOError::new_err(value.to_string()), + GrimpError::PackageDirectoryNotFound(_) => { + PyFileNotFoundError::new_err(value.to_string()) + } } } } diff --git a/rust/src/graph/builder/cache.rs b/rust/src/graph/builder/cache.rs index c4af53f9..624b3ca7 100644 --- a/rust/src/graph/builder/cache.rs +++ b/rust/src/graph/builder/cache.rs @@ -5,6 +5,7 @@ use std::path::{Path, PathBuf}; use bincode::{Decode, Encode}; +use crate::errors::{GrimpError, GrimpResult}; use crate::import_parsing::ImportedObject; #[derive(Debug, Clone, Encode, Decode)] @@ -48,20 +49,31 @@ pub fn load_cache(cache_dir: &Path, package_name: &str) -> ImportCache { HashMap::new() } -pub fn save_cache(cache: &ImportCache, cache_dir: &Path, package_name: &str) { - if let Err(e) = fs::create_dir_all(cache_dir) { - eprintln!("Failed to create cache directory: {}", e); - return; - } +pub fn save_cache(cache: &ImportCache, cache_dir: &Path, package_name: &str) -> GrimpResult<()> { + fs::create_dir_all(cache_dir).map_err(|e| GrimpError::CacheWriteError { + path: cache_dir.display().to_string(), + error: e.to_string(), + })?; let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); - match bincode::encode_to_vec(cache, bincode::config::standard()) { - Ok(encoded) => { - if let Ok(mut file) = fs::File::create(&cache_file) { - let _ = file.write_all(&encoded); - } + let encoded = bincode::encode_to_vec(cache, bincode::config::standard()).map_err(|e| { + GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), } - Err(e) => eprintln!("Failed to encode cache: {}", e), - } + })?; + + let mut file = fs::File::create(&cache_file).map_err(|e| GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + })?; + + file.write_all(&encoded) + .map_err(|e| GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + })?; + + Ok(()) } diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 8b7da59c..8af13fe7 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -8,6 +8,7 @@ use crossbeam::channel; use derive_new::new; use ignore::WalkBuilder; +use crate::errors::{GrimpError, GrimpResult}; use crate::graph::Graph; use crate::import_parsing::{ImportedObject, parse_imports_from_code}; @@ -31,22 +32,31 @@ pub fn build_graph( include_external_packages: bool, exclude_type_checking_imports: bool, cache_dir: Option<&PathBuf>, -) -> Graph { +) -> GrimpResult { + // Check if package directory exists + if !package.directory.exists() { + return Err(GrimpError::PackageDirectoryNotFound( + package.directory.display().to_string(), + )); + } + // Load cache if available let mut cache = cache_dir .map(|dir| load_cache(dir, &package.name)) .unwrap_or_default(); // Create channels for communication - let (module_discovery_sender, module_discovery_receiver) = channel::bounded(10000); - let (import_parser_sender, import_parser_receiver) = channel::bounded(10000); + // This way we can start parsing moduels while we're still discovering them. + let (found_module_sender, found_module_receiver) = channel::bounded(10000); + let (parsed_module_sender, parser_module_receiver) = channel::bounded(10000); + let (error_sender, error_receiver) = channel::bounded(1); let mut thread_handles = Vec::new(); // Thread 1: Discover modules let package_clone = package.clone(); let handle = thread::spawn(move || { - discover_python_modules(&package_clone, module_discovery_sender); + discover_python_modules(&package_clone, found_module_sender); }); thread_handles.push(handle); @@ -55,24 +65,33 @@ pub fn build_graph( .map(|n| max(n.get() / 2, 1)) .unwrap_or(4); for _ in 0..num_workers { - let receiver = module_discovery_receiver.clone(); - let sender = import_parser_sender.clone(); + let receiver = found_module_receiver.clone(); + let sender = parsed_module_sender.clone(); + let error_sender = error_sender.clone(); let cache = cache.clone(); let handle = thread::spawn(move || { while let Ok(module) = receiver.recv() { - if let Some(parsed) = parse_module_imports(&module, &cache) { - sender.send(parsed).unwrap(); + match parse_module_imports(&module, &cache) { + Ok(parsed) => { + let _ = sender.send(parsed); + } + Err(e) => { + // Channel has capacity of 1, since we only care to catch one error. + // Drop further errors. + let _ = error_sender.try_send(e); + } } } }); thread_handles.push(handle); } - drop(module_discovery_receiver); // Close original receiver - drop(import_parser_sender); // Close original sender - // Collect parsed modules + // Close original receivers/senders so threads know when to stop + drop(parsed_module_sender); // Main thread will know when no more parsed modules + + // Collect parsed modules (this will continue until all parser threads finish and close their senders) let mut parsed_modules = Vec::new(); - while let Ok(parsed) = import_parser_receiver.recv() { + while let Ok(parsed) = parser_module_receiver.recv() { parsed_modules.push(parsed); } @@ -81,6 +100,11 @@ pub fn build_graph( handle.join().unwrap(); } + // Check if any errors occurred + if let Ok(error) = error_receiver.try_recv() { + return Err(error); + } + // Update and save cache if cache_dir is set if let Some(cache_dir) = cache_dir { for parsed in &parsed_modules { @@ -89,17 +113,17 @@ pub fn build_graph( CachedImports::new(parsed.module.mtime_secs, parsed.imported_objects.clone()), ); } - save_cache(&cache, cache_dir, &package.name); + save_cache(&cache, cache_dir, &package.name)?; } - // Resolve imports and assemble graph (sequential) + // Resolve imports and assemble graph let imports_by_module = resolve_imports( &parsed_modules, include_external_packages, exclude_type_checking_imports, ); - assemble_graph(&imports_by_module, &package.name) + Ok(assemble_graph(&imports_by_module, &package.name)) } #[derive(Debug, Clone)] @@ -154,6 +178,12 @@ fn discover_python_modules(package: &PackageSpec, sender: channel::Sender Option { +fn parse_module_imports(module: &FoundModule, cache: &ImportCache) -> GrimpResult { // Check if we have a cached version with matching mtime if let Some(cached) = cache.get(&module.path) && module.mtime_secs == cached.mtime_secs() { // Cache hit - use cached imports - return Some(ParsedModule { + return Ok(ParsedModule { module: module.clone(), imported_objects: cached.imported_objects().to_vec(), }); } // Cache miss or file modified - parse the file - let code = fs::read_to_string(&module.path).ok()?; - let imported_objects = - parse_imports_from_code(&code, module.path.to_str().unwrap_or("")).ok()?; - Some(ParsedModule { + let code = fs::read_to_string(&module.path).map_err(|e| GrimpError::FileReadError { + path: module.path.display().to_string(), + error: e.to_string(), + })?; + + let imported_objects = parse_imports_from_code(&code, module.path.to_str().unwrap_or(""))?; + + Ok(ParsedModule { module: module.clone(), imported_objects, }) diff --git a/rust/src/graph_building.rs b/rust/src/graph_building.rs index b3072662..bcd4d631 100644 --- a/rust/src/graph_building.rs +++ b/rust/src/graph_building.rs @@ -27,13 +27,13 @@ pub fn build_graph_rust( include_external_packages: bool, exclude_type_checking_imports: bool, cache_dir: Option, -) -> GraphWrapper { +) -> PyResult { let cache_path = cache_dir.map(PathBuf::from); let graph = build_graph( &package.inner, include_external_packages, exclude_type_checking_imports, cache_path.as_ref(), - ); - GraphWrapper::from_graph(graph) + )?; + Ok(GraphWrapper::from_graph(graph)) } From ca84423e90dd3bd577fb66cee548c9e90bc67e58 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 15:35:17 +0100 Subject: [PATCH 10/19] Extract helper functions to get number of threads --- rust/src/graph/builder/mod.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 8af13fe7..4852088a 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -61,10 +61,8 @@ pub fn build_graph( thread_handles.push(handle); // Thread pool: Parse imports - let num_workers = thread::available_parallelism() - .map(|n| max(n.get() / 2, 1)) - .unwrap_or(4); - for _ in 0..num_workers { + let num_threads = num_threads_for_module_parsing(); + for _ in 0..num_threads { let receiver = found_module_receiver.clone(); let sender = parsed_module_sender.clone(); let error_sender = error_sender.clone(); @@ -141,10 +139,7 @@ struct ParsedModule { } fn discover_python_modules(package: &PackageSpec, sender: channel::Sender) { - let num_threads = thread::available_parallelism() - .map(|n| max(n.get() / 2, 1)) - .unwrap_or(4); - + let num_threads = num_threads_for_module_discovery(); let package_clone = package.clone(); WalkBuilder::new(&package.directory) @@ -324,3 +319,19 @@ fn assemble_graph( graph } + +/// Calculate the number of threads to use for module discovery. +/// Uses half of available parallelism, with a minimum of 1 and default of 4. +fn num_threads_for_module_discovery() -> usize { + thread::available_parallelism() + .map(|n| max(n.get() / 2, 1)) + .unwrap_or(4) +} + +/// Calculate the number of threads to use for module parsing. +/// Uses half of available parallelism, with a minimum of 1 and default of 4. +fn num_threads_for_module_parsing() -> usize { + thread::available_parallelism() + .map(|n| max(n.get() / 2, 1)) + .unwrap_or(4) +} From d3091f28efaa2471c3867b8b39a9f4571a83e109 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 16:05:59 +0100 Subject: [PATCH 11/19] Reduce channel capacities 1000 is still enough, and it allows deadlocks to show up in the test suite if we've gotten the logic wrong. --- rust/src/graph/builder/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 4852088a..c5f22297 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -47,8 +47,8 @@ pub fn build_graph( // Create channels for communication // This way we can start parsing moduels while we're still discovering them. - let (found_module_sender, found_module_receiver) = channel::bounded(10000); - let (parsed_module_sender, parser_module_receiver) = channel::bounded(10000); + let (found_module_sender, found_module_receiver) = channel::bounded(1000); + let (parsed_module_sender, parser_module_receiver) = channel::bounded(1000); let (error_sender, error_receiver) = channel::bounded(1); let mut thread_handles = Vec::new(); From 8034e9955c3074965f1c2ddb4154b50cd6368f55 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 16:10:44 +0100 Subject: [PATCH 12/19] Create ImportGraph.from_rustgraph --- src/grimp/application/graph.py | 6 ++++++ src/grimp/application/usecases.py | 4 +--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/grimp/application/graph.py b/src/grimp/application/graph.py index b6350c87..7bffddf4 100644 --- a/src/grimp/application/graph.py +++ b/src/grimp/application/graph.py @@ -39,6 +39,12 @@ def __init__(self) -> None: self._cached_modules: set[str] | None = None self._rustgraph = rust.Graph() + @classmethod + def from_rustgraph(cls, rustgraph: rust.Graph) -> ImportGraph: + graph = ImportGraph() + graph._rustgraph = rustgraph + return graph + # Mechanics # --------- diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index 4eb05440..215c4768 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -107,9 +107,7 @@ def build_graph_rust( ) # Wrap the rust graph in our ImportGraph wrapper - graph = ImportGraph() - graph._rustgraph = rust_graph - return graph + return ImportGraph.from_rustgraph(rust_graph) def _find_packages( From 70223f23bdc38ed5dc20b878483e136bef562862 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 18:21:04 +0100 Subject: [PATCH 13/19] More tweaks Because I can't resist --- rust/src/graph/builder/cache.rs | 120 +++++++++++------ rust/src/graph/builder/mod.rs | 227 +++++++++++++++++++------------- 2 files changed, 215 insertions(+), 132 deletions(-) diff --git a/rust/src/graph/builder/cache.rs b/rust/src/graph/builder/cache.rs index 624b3ca7..dfb92ce1 100644 --- a/rust/src/graph/builder/cache.rs +++ b/rust/src/graph/builder/cache.rs @@ -9,71 +9,117 @@ use crate::errors::{GrimpError, GrimpResult}; use crate::import_parsing::ImportedObject; #[derive(Debug, Clone, Encode, Decode)] -pub struct CachedImports { +struct CachedImports { mtime_secs: i64, imported_objects: Vec, } impl CachedImports { - pub fn new(mtime_secs: i64, imported_objects: Vec) -> Self { + fn new(mtime_secs: i64, imported_objects: Vec) -> Self { Self { mtime_secs, imported_objects, } } - pub fn mtime_secs(&self) -> i64 { + fn mtime_secs(&self) -> i64 { self.mtime_secs } - pub fn imported_objects(&self) -> &[ImportedObject] { + fn imported_objects(&self) -> &[ImportedObject] { &self.imported_objects } } -pub type ImportCache = HashMap; - -pub fn load_cache(cache_dir: &Path, package_name: &str) -> ImportCache { - let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); +/// Cache for storing parsed import information indexed by module name. +#[derive(Debug, Clone)] +pub struct ImportsCache { + cache_dir: PathBuf, + package_name: String, + cache: HashMap, +} - if let Ok(mut file) = fs::File::open(&cache_file) { - let mut buffer = Vec::new(); - if file.read_to_end(&mut buffer).is_ok() - && let Ok(cache) = - bincode::decode_from_slice::(&buffer, bincode::config::standard()) - { - return cache.0; +impl ImportsCache { + /// Get cached imports for a module if they exist and the mtime matches. + pub fn get_imports(&self, module_name: &str, mtime_secs: i64) -> Option> { + let cached = self.cache.get(module_name)?; + if cached.mtime_secs() == mtime_secs { + Some(cached.imported_objects().to_vec()) + } else { + None } } - HashMap::new() -} - -pub fn save_cache(cache: &ImportCache, cache_dir: &Path, package_name: &str) -> GrimpResult<()> { - fs::create_dir_all(cache_dir).map_err(|e| GrimpError::CacheWriteError { - path: cache_dir.display().to_string(), - error: e.to_string(), - })?; - - let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); + /// Store parsed imports for a module. + pub fn set_imports( + &mut self, + module_name: String, + mtime_secs: i64, + imports: Vec, + ) { + self.cache + .insert(module_name, CachedImports::new(mtime_secs, imports)); + } - let encoded = bincode::encode_to_vec(cache, bincode::config::standard()).map_err(|e| { - GrimpError::CacheWriteError { - path: cache_file.display().to_string(), + /// Save the cache to disk. + pub fn save(&self) -> GrimpResult<()> { + fs::create_dir_all(&self.cache_dir).map_err(|e| GrimpError::CacheWriteError { + path: self.cache_dir.display().to_string(), error: e.to_string(), - } - })?; + })?; + + let cache_file = self + .cache_dir + .join(format!("{}.imports.bincode", self.package_name)); - let mut file = fs::File::create(&cache_file).map_err(|e| GrimpError::CacheWriteError { - path: cache_file.display().to_string(), - error: e.to_string(), - })?; + let encoded = + bincode::encode_to_vec(&self.cache, bincode::config::standard()).map_err(|e| { + GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + } + })?; - file.write_all(&encoded) - .map_err(|e| GrimpError::CacheWriteError { + let mut file = fs::File::create(&cache_file).map_err(|e| GrimpError::CacheWriteError { path: cache_file.display().to_string(), error: e.to_string(), })?; - Ok(()) + file.write_all(&encoded) + .map_err(|e| GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + })?; + + Ok(()) + } +} + +/// Load cache from disk. +pub fn load_cache(cache_dir: &Path, package_name: &str) -> ImportsCache { + let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); + + let cache_map = if let Ok(mut file) = fs::File::open(&cache_file) { + let mut buffer = Vec::new(); + if file.read_to_end(&mut buffer).is_ok() { + if let Ok(decoded) = bincode::decode_from_slice::, _>( + &buffer, + bincode::config::standard(), + ) { + decoded.0 + } else { + HashMap::new() + } + } else { + HashMap::new() + } + } else { + HashMap::new() + }; + + ImportsCache { + cache: cache_map, + cache_dir: cache_dir.to_path_buf(), + package_name: package_name.to_string(), + } } diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index c5f22297..3c9cea82 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -13,7 +13,7 @@ use crate::graph::Graph; use crate::import_parsing::{ImportedObject, parse_imports_from_code}; mod cache; -use cache::{CachedImports, ImportCache, load_cache, save_cache}; +use cache::{ImportsCache, load_cache}; mod utils; use utils::{ @@ -28,7 +28,7 @@ pub struct PackageSpec { } pub fn build_graph( - package: &PackageSpec, // TODO(peter) Support multiple packages + package: &PackageSpec, include_external_packages: bool, exclude_type_checking_imports: bool, cache_dir: Option<&PathBuf>, @@ -40,79 +40,8 @@ pub fn build_graph( )); } - // Load cache if available - let mut cache = cache_dir - .map(|dir| load_cache(dir, &package.name)) - .unwrap_or_default(); - - // Create channels for communication - // This way we can start parsing moduels while we're still discovering them. - let (found_module_sender, found_module_receiver) = channel::bounded(1000); - let (parsed_module_sender, parser_module_receiver) = channel::bounded(1000); - let (error_sender, error_receiver) = channel::bounded(1); - - let mut thread_handles = Vec::new(); - - // Thread 1: Discover modules - let package_clone = package.clone(); - let handle = thread::spawn(move || { - discover_python_modules(&package_clone, found_module_sender); - }); - thread_handles.push(handle); - - // Thread pool: Parse imports - let num_threads = num_threads_for_module_parsing(); - for _ in 0..num_threads { - let receiver = found_module_receiver.clone(); - let sender = parsed_module_sender.clone(); - let error_sender = error_sender.clone(); - let cache = cache.clone(); - let handle = thread::spawn(move || { - while let Ok(module) = receiver.recv() { - match parse_module_imports(&module, &cache) { - Ok(parsed) => { - let _ = sender.send(parsed); - } - Err(e) => { - // Channel has capacity of 1, since we only care to catch one error. - // Drop further errors. - let _ = error_sender.try_send(e); - } - } - } - }); - thread_handles.push(handle); - } - - // Close original receivers/senders so threads know when to stop - drop(parsed_module_sender); // Main thread will know when no more parsed modules - - // Collect parsed modules (this will continue until all parser threads finish and close their senders) - let mut parsed_modules = Vec::new(); - while let Ok(parsed) = parser_module_receiver.recv() { - parsed_modules.push(parsed); - } - - // Wait for all threads to complete - for handle in thread_handles { - handle.join().unwrap(); - } - - // Check if any errors occurred - if let Ok(error) = error_receiver.try_recv() { - return Err(error); - } - - // Update and save cache if cache_dir is set - if let Some(cache_dir) = cache_dir { - for parsed in &parsed_modules { - cache.insert( - parsed.module.path.clone(), - CachedImports::new(parsed.module.mtime_secs, parsed.imported_objects.clone()), - ); - } - save_cache(&cache, cache_dir, &package.name)?; - } + // Discover and parse all modules in parallel + let parsed_modules = discover_and_parse_modules(package, cache_dir)?; // Resolve imports and assemble graph let imports_by_module = resolve_imports( @@ -138,10 +67,111 @@ struct ParsedModule { imported_objects: Vec, } -fn discover_python_modules(package: &PackageSpec, sender: channel::Sender) { - let num_threads = num_threads_for_module_discovery(); - let package_clone = package.clone(); +/// Discover and parse all Python modules in a package using parallel processing. +/// +/// # Concurrency Model +/// +/// This function uses a pipeline architecture with three stages: +/// +/// 1. **Discovery Stage** (1 thread): +/// - Walks the package directory to find Python files +/// - Sends found modules to the parsing stage via `found_module_sender` +/// - Closes the channel when complete +/// +/// 2. **Parsing Stage** (N worker threads): +/// - Each thread receives modules from `found_module_receiver` +/// - Parses imports from each module (with caching) +/// - Sends parsed modules to the collection stage via `parsed_module_sender` +/// - Threads exit when `found_module_sender` is dropped (discovery complete) +/// +/// 3. **Collection Stage** (main thread): +/// - Receives parsed modules from `parser_module_receiver` +/// - Stops when all parser threads exit and drop their `parsed_module_sender` clones +/// +/// # Error Handling +/// +/// Parse errors are sent via `error_sender`. We only capture the first error and return it. +/// Subsequent errors are dropped since we fail fast on the first error. +/// +/// # Returns +/// +/// Returns a vector parsed modules, or the first error encountered. +fn discover_and_parse_modules( + package: &PackageSpec, + cache_dir: Option<&PathBuf>, +) -> GrimpResult> { + let thread_counts = calculate_thread_counts(); + + thread::scope(|scope| { + // Load cache if available + let mut cache = cache_dir.map(|dir| load_cache(dir, &package.name)); + // Create channels for the pipeline + let (found_module_sender, found_module_receiver) = channel::bounded(1000); + let (parsed_module_sender, parsed_module_receiver) = channel::bounded(1000); + let (error_sender, error_receiver) = channel::bounded(1); + + // Stage 1: Discovery thread + scope.spawn(|| { + discover_python_modules(package, thread_counts.module_discovery, found_module_sender); + }); + + // Stage 2: Parser thread pool + for _ in 0..thread_counts.module_parsing { + let receiver = found_module_receiver.clone(); + let sender = parsed_module_sender.clone(); + let error_sender = error_sender.clone(); + let cache = cache.clone(); + scope.spawn(move || { + while let Ok(module) = receiver.recv() { + match parse_module_imports(&module, cache.as_ref()) { + Ok(parsed) => { + let _ = sender.send(parsed); + } + Err(e) => { + let _ = error_sender.try_send(e); + } + } + } + }); + } + + // Close our copy of the sender so the receiver knows when all threads are done + drop(parsed_module_sender); + + // Stage 3: Collection (in main thread) + let mut parsed_modules = Vec::new(); + while let Ok(parsed) = parsed_module_receiver.recv() { + parsed_modules.push(parsed); + } + + // Check if any errors occurred + if let Ok(error) = error_receiver.try_recv() { + return Err(error); + } + + // Update and save cache if present + if let Some(cache) = &mut cache { + for parsed in &parsed_modules { + cache.set_imports( + parsed.module.name.clone(), + parsed.module.mtime_secs, + parsed.imported_objects.clone(), + ); + } + cache.save()?; + } + + Ok(parsed_modules) + }) +} + +fn discover_python_modules( + package: &PackageSpec, + num_threads: usize, + sender: channel::Sender, +) { + let package = package.clone(); WalkBuilder::new(&package.directory) .standard_filters(false) // Don't use gitignore or other filters .threads(num_threads) @@ -162,7 +192,7 @@ fn discover_python_modules(package: &PackageSpec, sender: channel::Sender GrimpResult { +fn parse_module_imports( + module: &FoundModule, + cache: Option<&ImportsCache>, +) -> GrimpResult { // Check if we have a cached version with matching mtime - if let Some(cached) = cache.get(&module.path) - && module.mtime_secs == cached.mtime_secs() + if let Some(cache) = cache + && let Some(imported_objects) = cache.get_imports(&module.name, module.mtime_secs) { // Cache hit - use cached imports return Ok(ParsedModule { module: module.clone(), - imported_objects: cached.imported_objects().to_vec(), + imported_objects, }); } @@ -320,18 +353,22 @@ fn assemble_graph( graph } -/// Calculate the number of threads to use for module discovery. -/// Uses half of available parallelism, with a minimum of 1 and default of 4. -fn num_threads_for_module_discovery() -> usize { - thread::available_parallelism() - .map(|n| max(n.get() / 2, 1)) - .unwrap_or(4) +/// Thread counts for parallel processing stages. +struct ThreadCounts { + module_discovery: usize, + module_parsing: usize, } -/// Calculate the number of threads to use for module parsing. -/// Uses half of available parallelism, with a minimum of 1 and default of 4. -fn num_threads_for_module_parsing() -> usize { - thread::available_parallelism() - .map(|n| max(n.get() / 2, 1)) - .unwrap_or(4) +/// Calculate the number of threads to use for parallel operations. +/// Uses 2/3 of available CPUs for both module discovery and parsing. +/// Since both module discovery and parsing involve some IO, it makes sense to +/// use slighly more threads than the available CPUs. +fn calculate_thread_counts() -> ThreadCounts { + let num_threads = thread::available_parallelism() + .map(|n| max((2 * n.get()) / 3, 1)) + .unwrap_or(4); + ThreadCounts { + module_discovery: num_threads, + module_parsing: num_threads, + } } From dd95e72731875d8bfef64c630d4e1c5b8985437e Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Mon, 10 Nov 2025 20:41:20 +0100 Subject: [PATCH 14/19] Add distill_external_module logic to handle namespace packages --- rust/src/graph/builder/mod.rs | 25 ++++++----- rust/src/graph/builder/utils.rs | 77 +++++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 23 deletions(-) diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 3c9cea82..9fbb4c7d 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -17,7 +17,7 @@ use cache::{ImportsCache, load_cache}; mod utils; use utils::{ - ResolvedImport, is_internal, is_package, path_to_module_name, resolve_external_module, + ResolvedImport, distill_external_module, is_internal, is_package, path_to_module_name, resolve_internal_module, resolve_relative_import, }; @@ -28,7 +28,7 @@ pub struct PackageSpec { } pub fn build_graph( - package: &PackageSpec, + package: &PackageSpec, // TODO(peter) Support multiple packages include_external_packages: bool, exclude_type_checking_imports: bool, cache_dir: Option<&PathBuf>, @@ -48,6 +48,7 @@ pub fn build_graph( &parsed_modules, include_external_packages, exclude_type_checking_imports, + &HashSet::from([package.name.to_owned()]), ); Ok(assemble_graph(&imports_by_module, &package.name)) @@ -269,6 +270,7 @@ fn resolve_imports( parsed_modules: &[ParsedModule], include_external_packages: bool, exclude_type_checking_imports: bool, + packages: &HashSet, ) -> HashMap> { let all_modules: HashSet = parsed_modules .iter() @@ -304,14 +306,17 @@ fn resolve_imports( line_contents: imported_object.line_contents.clone(), }); } else if include_external_packages { - // It's an external module and we're including them - let external_module = resolve_external_module(&absolute_import_name); - resolved_imports.insert(ResolvedImport { - importer: parsed_module.module.name.to_string(), - imported: external_module, - line_number: imported_object.line_number, - line_contents: imported_object.line_contents.clone(), - }); + // Try to resolve as an external module + if let Some(external_module) = + distill_external_module(&absolute_import_name, packages) + { + resolved_imports.insert(ResolvedImport { + importer: parsed_module.module.name.to_string(), + imported: external_module, + line_number: imported_object.line_number, + line_contents: imported_object.line_contents.clone(), + }); + } } } diff --git a/rust/src/graph/builder/utils.rs b/rust/src/graph/builder/utils.rs index cd8c67f9..46d24e5a 100644 --- a/rust/src/graph/builder/utils.rs +++ b/rust/src/graph/builder/utils.rs @@ -20,12 +20,14 @@ pub fn is_package(module_path: &Path) -> bool { .unwrap_or(false) } +/// Check if a module is a descendant of another module. +pub fn is_descendant(module_name: &str, potential_ancestor: &str) -> bool { + module_name.starts_with(&format!("{}.", potential_ancestor)) +} + /// Check if module is internal pub fn is_internal(module_name: &str, package: &str) -> bool { - if module_name == package || module_name.starts_with(&format!("{}.", package)) { - return true; - } - false + module_name == package || is_descendant(module_name, package) } /// Convert module path to module name @@ -105,13 +107,62 @@ pub fn resolve_internal_module( None } -/// Get external module name -pub fn resolve_external_module(module_name: &str) -> String { - // For simplicity, just return the root module for external imports - // This matches the basic behavior from _distill_external_module - module_name - .split('.') - .next() - .unwrap_or(module_name) - .to_string() +/// Given a module that we already know is external, turn it into a module to add to the graph. +/// +/// The 'distillation' process involves removing any unwanted subpackages. For example, +/// django.models.db should be turned into simply django. +/// +/// The process is more complex for potential namespace packages, as it's not possible to +/// determine the portion package simply from name. Rather than adding the overhead of a +/// filesystem read, we just get the shallowest component that does not clash with an internal +/// module namespace. Take, for example, foo.blue.alpha.one. If one of the found +/// packages is foo.blue.beta, the module will be distilled to foo.blue.alpha. +/// Alternatively, if the found package is foo.green, the distilled module will +/// be foo.blue. +/// +/// Returns None if the module is a parent of one of the internal packages (doesn't make sense, +/// probably an import of a namespace package). +pub fn distill_external_module( + module_name: &str, + found_package_names: &HashSet, +) -> Option { + for found_package in found_package_names { + // If it's a module that is a parent of the package, return None + // as it doesn't make sense and is probably an import of a namespace package. + if is_descendant(found_package, module_name) { + return None; + } + } + + let module_root = module_name.split('.').next().unwrap(); + + // If it shares a namespace with an internal module, get the shallowest component that does + // not clash with an internal module namespace. + let mut candidate_portions: Vec = Vec::new(); + let mut sorted_found_packages: Vec<&String> = found_package_names.iter().collect(); + sorted_found_packages.sort(); + sorted_found_packages.reverse(); + + for found_package in sorted_found_packages { + if is_descendant(found_package, module_root) { + let mut internal_components: Vec<&str> = found_package.split('.').collect(); + let mut external_components: Vec<&str> = module_name.split('.').collect(); + let mut external_namespace_components: Vec<&str> = vec![]; + while external_components[0] == internal_components[0] { + external_namespace_components.push(external_components.remove(0)); + internal_components.remove(0); + } + external_namespace_components.push(external_components[0]); + candidate_portions.push(external_namespace_components.join(".")); + } + } + + if !candidate_portions.is_empty() { + // If multiple internal modules share a namespace with this module, use the deepest one + // as we know that that will be a namespace too. + candidate_portions.sort_by_key(|portion| portion.split('.').count()); + Some(candidate_portions.last().unwrap().clone()) + } else { + Some(module_root.to_string()) + } } From ae0ef856152983a1700712d5a63ef9c744ab56c5 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Tue, 11 Nov 2025 02:51:38 +0100 Subject: [PATCH 15/19] Support multiple packages, fix a few bugs --- rust/src/graph/builder/mod.rs | 194 ++++++++++++++++++------------ rust/src/graph/builder/utils.rs | 111 ++++++++++++++++- rust/src/graph_building.rs | 35 +++++- src/grimp/application/usecases.py | 17 +-- 4 files changed, 263 insertions(+), 94 deletions(-) diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 9fbb4c7d..8d3c34ff 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -18,7 +18,7 @@ use cache::{ImportsCache, load_cache}; mod utils; use utils::{ ResolvedImport, distill_external_module, is_internal, is_package, path_to_module_name, - resolve_internal_module, resolve_relative_import, + read_python_file, resolve_internal_module, resolve_relative_import, }; #[derive(Debug, Clone, new)] @@ -28,34 +28,38 @@ pub struct PackageSpec { } pub fn build_graph( - package: &PackageSpec, // TODO(peter) Support multiple packages + packages: &[PackageSpec], include_external_packages: bool, exclude_type_checking_imports: bool, cache_dir: Option<&PathBuf>, ) -> GrimpResult { - // Check if package directory exists - if !package.directory.exists() { - return Err(GrimpError::PackageDirectoryNotFound( - package.directory.display().to_string(), - )); + // Check all package directories exist + for package in packages { + if !package.directory.exists() { + return Err(GrimpError::PackageDirectoryNotFound( + package.directory.display().to_string(), + )); + } } + let package_names: HashSet = packages.iter().map(|p| p.name.clone()).collect(); - // Discover and parse all modules in parallel - let parsed_modules = discover_and_parse_modules(package, cache_dir)?; + // Discover and parse all modules from all packages in parallel + let parsed_modules = discover_and_parse_modules(packages, cache_dir)?; // Resolve imports and assemble graph let imports_by_module = resolve_imports( &parsed_modules, include_external_packages, exclude_type_checking_imports, - &HashSet::from([package.name.to_owned()]), + &package_names, ); - Ok(assemble_graph(&imports_by_module, &package.name)) + Ok(assemble_graph(&imports_by_module, &package_names)) } #[derive(Debug, Clone)] struct FoundModule { + package_name: String, name: String, path: PathBuf, is_package: bool, @@ -68,14 +72,14 @@ struct ParsedModule { imported_objects: Vec, } -/// Discover and parse all Python modules in a package using parallel processing. +/// Discover and parse all Python modules in one or more packages using parallel processing. /// /// # Concurrency Model /// /// This function uses a pipeline architecture with three stages: /// /// 1. **Discovery Stage** (1 thread): -/// - Walks the package directory to find Python files +/// - Walks all package directories to find Python files /// - Sends found modules to the parsing stage via `found_module_sender` /// - Closes the channel when complete /// @@ -98,14 +102,19 @@ struct ParsedModule { /// /// Returns a vector parsed modules, or the first error encountered. fn discover_and_parse_modules( - package: &PackageSpec, + packages: &[PackageSpec], cache_dir: Option<&PathBuf>, ) -> GrimpResult> { let thread_counts = calculate_thread_counts(); thread::scope(|scope| { - // Load cache if available - let mut cache = cache_dir.map(|dir| load_cache(dir, &package.name)); + // Load caches for all packages if available - store in HashMap by package name + let caches: Option> = cache_dir.map(|dir| { + packages + .iter() + .map(|pkg| (pkg.name.clone(), load_cache(dir, &pkg.name))) + .collect() + }); // Create channels for the pipeline let (found_module_sender, found_module_receiver) = channel::bounded(1000); @@ -114,7 +123,11 @@ fn discover_and_parse_modules( // Stage 1: Discovery thread scope.spawn(|| { - discover_python_modules(package, thread_counts.module_discovery, found_module_sender); + discover_python_modules( + packages, + thread_counts.module_discovery, + found_module_sender, + ); }); // Stage 2: Parser thread pool @@ -122,10 +135,15 @@ fn discover_and_parse_modules( let receiver = found_module_receiver.clone(); let sender = parsed_module_sender.clone(); let error_sender = error_sender.clone(); - let cache = cache.clone(); + let caches = caches.clone(); scope.spawn(move || { while let Ok(module) = receiver.recv() { - match parse_module_imports(&module, cache.as_ref()) { + // Look up the cache for this module's package + let cache = caches + .as_ref() + .and_then(|map| map.get(&module.package_name)); + + match parse_module_imports(&module, cache) { Ok(parsed) => { let _ = sender.send(parsed); } @@ -151,16 +169,20 @@ fn discover_and_parse_modules( return Err(error); } - // Update and save cache if present - if let Some(cache) = &mut cache { + // Update and save all caches + if let Some(mut caches) = caches { for parsed in &parsed_modules { - cache.set_imports( - parsed.module.name.clone(), - parsed.module.mtime_secs, - parsed.imported_objects.clone(), - ); + if let Some(cache) = caches.get_mut(&parsed.module.package_name) { + cache.set_imports( + parsed.module.name.clone(), + parsed.module.mtime_secs, + parsed.imported_objects.clone(), + ); + } + } + for cache in caches.values_mut() { + cache.save()?; } - cache.save()?; } Ok(parsed_modules) @@ -168,18 +190,30 @@ fn discover_and_parse_modules( } fn discover_python_modules( - package: &PackageSpec, + packages: &[PackageSpec], num_threads: usize, sender: channel::Sender, ) { - let package = package.clone(); - WalkBuilder::new(&package.directory) + let packages: Vec = packages.to_vec(); + + let mut builder = WalkBuilder::new(&packages[0].directory); + for package in &packages[1..] { + builder.add(&package.directory); + } + builder .standard_filters(false) // Don't use gitignore or other filters + .hidden(true) // Ignore hidden files/directories .threads(num_threads) .filter_entry(|entry| { - // Allow Python files + // Allow Python files (but skip files with multiple dots like dotted.module.py) if entry.file_type().is_some_and(|ft| ft.is_file()) { - return entry.path().extension().and_then(|s| s.to_str()) == Some("py"); + if entry.path().extension().and_then(|s| s.to_str()) == Some("py") { + // Check if filename has multiple dots (invalid Python module names) + if let Some(file_name) = entry.file_name().to_str() { + return file_name.matches('.').count() == 1; // Only the .py extension + } + } + return false; } // For directories, only descend if they contain __init__.py @@ -189,52 +223,59 @@ fn discover_python_modules( } false - }) - .build_parallel() - .run(|| { - let sender = sender.clone(); - let package = package.clone(); + }); - Box::new(move |entry| { - use ignore::WalkState; + builder.build_parallel().run(|| { + let sender = sender.clone(); + let packages = packages.clone(); - let entry = match entry { - Ok(e) => e, - Err(_) => return WalkState::Continue, - }; + Box::new(move |entry| { + use ignore::WalkState; - let path = entry.path(); + let entry = match entry { + Ok(e) => e, + Err(_) => return WalkState::Continue, + }; - // Only process files, not directories - if !path.is_file() { - return WalkState::Continue; - } + let path = entry.path(); - if let Some(module_name) = path_to_module_name(path, &package) { - let is_package = is_package(path); - - // Get mtime - let mtime_secs = fs::metadata(path) - .and_then(|m| m.modified()) - .ok() - .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) - .map(|d| d.as_secs() as i64) - .unwrap_or(0); - - let found_module = FoundModule { - name: module_name, - path: path.to_owned(), - is_package, - mtime_secs, - }; - - // Send module as soon as we discover it - let _ = sender.send(found_module); - } + // Only process files, not directories + if !path.is_file() { + return WalkState::Continue; + } - WalkState::Continue - }) - }); + // Find which package this file belongs to by checking if path starts with package directory + let package = packages + .iter() + .find(|pkg| path.starts_with(&pkg.directory)) + .unwrap(); + + if let Some(module_name) = path_to_module_name(path, package) { + let is_package = is_package(path); + + // Get mtime + let mtime_secs = fs::metadata(path) + .and_then(|m| m.modified()) + .ok() + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| d.as_secs() as i64) + .unwrap_or(0); + + let found_module = FoundModule { + package_name: package.name.clone(), + name: module_name, + path: path.to_owned(), + is_package, + mtime_secs, + }; + + // Send module as soon as we discover it + let _ = sender.send(found_module); + } + + WalkState::Continue + }) + }); } fn parse_module_imports( @@ -253,10 +294,7 @@ fn parse_module_imports( } // Cache miss or file modified - parse the file - let code = fs::read_to_string(&module.path).map_err(|e| GrimpError::FileReadError { - path: module.path.display().to_string(), - error: e.to_string(), - })?; + let code = read_python_file(&module.path)?; let imported_objects = parse_imports_from_code(&code, module.path.to_str().unwrap_or(""))?; @@ -328,7 +366,7 @@ fn resolve_imports( fn assemble_graph( imports_by_module: &HashMap>, - package_name: &str, + package_names: &HashSet, ) -> Graph { let mut graph = Graph::default(); @@ -339,7 +377,7 @@ fn assemble_graph( for import in imports { // Add the imported module - let imported_token = if is_internal(&import.imported, package_name) { + let imported_token = if is_internal(&import.imported, package_names) { graph.get_or_add_module(&import.imported).token() } else { graph.get_or_add_squashed_module(&import.imported).token() diff --git a/rust/src/graph/builder/utils.rs b/rust/src/graph/builder/utils.rs index 46d24e5a..9805f8bc 100644 --- a/rust/src/graph/builder/utils.rs +++ b/rust/src/graph/builder/utils.rs @@ -1,6 +1,11 @@ use std::collections::HashSet; +use std::fs; +use std::io::Read; use std::path::Path; +use encoding_rs::Encoding; + +use crate::errors::{GrimpError, GrimpResult}; use crate::graph::builder::PackageSpec; #[derive(Debug, Clone, Hash, Eq, PartialEq)] @@ -25,9 +30,11 @@ pub fn is_descendant(module_name: &str, potential_ancestor: &str) -> bool { module_name.starts_with(&format!("{}.", potential_ancestor)) } -/// Check if module is internal -pub fn is_internal(module_name: &str, package: &str) -> bool { - module_name == package || is_descendant(module_name, package) +/// Check if module is internal to any of the given packages +pub fn is_internal<'a>(module_name: &str, packages: impl IntoIterator) -> bool { + packages + .into_iter() + .any(|pkg| module_name == pkg || is_descendant(module_name, pkg)) } /// Convert module path to module name @@ -166,3 +173,101 @@ pub fn distill_external_module( Some(module_root.to_string()) } } + +/// Read a Python source file with proper encoding detection. +/// +/// Python PEP 263 specifies that encoding can be declared in the first or second line +/// in the format: `# coding: ` or `# -*- coding: -*-` +/// +/// This function: +/// 1. Reads the file as bytes +/// 2. Checks the first two lines for an encoding declaration +/// 3. Decodes the file using the detected encoding (or UTF-8 as default) +pub fn read_python_file(path: &Path) -> GrimpResult { + // Read file as bytes + let mut file = fs::File::open(path).map_err(|e| GrimpError::FileReadError { + path: path.display().to_string(), + error: e.to_string(), + })?; + + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes) + .map_err(|e| GrimpError::FileReadError { + path: path.display().to_string(), + error: e.to_string(), + })?; + + // Detect encoding from first two lines + let encoding = detect_python_encoding(&bytes); + + // Decode using detected encoding + let (decoded, _encoding_used, had_errors) = encoding.decode(&bytes); + + if had_errors { + return Err(GrimpError::FileReadError { + path: path.display().to_string(), + error: format!("Failed to decode file with encoding {}", encoding.name()), + }); + } + + Ok(decoded.into_owned()) +} + +/// Detect Python source file encoding from the first two lines. +/// +/// Looks for patterns like: +/// - `# coding: ` +/// - `# -*- coding: -*-` +/// - `# coding=` +fn detect_python_encoding(bytes: &[u8]) -> &'static Encoding { + // Read first two lines as ASCII (encoding declarations must be ASCII-compatible) + let mut line_count = 0; + let mut line_start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + if byte == b'\n' { + line_count += 1; + if line_count <= 2 { + // Check this line for encoding declaration + let line = &bytes[line_start..i]; + if let Some(encoding) = extract_encoding_from_line(line) { + return encoding; + } + line_start = i + 1; + } else { + break; + } + } + } + + // Default to UTF-8 + encoding_rs::UTF_8 +} + +/// Extract encoding from a single line if it contains an encoding declaration. +fn extract_encoding_from_line(line: &[u8]) -> Option<&'static Encoding> { + // Convert line to string (should be ASCII for encoding declarations) + let line_str = std::str::from_utf8(line).ok()?; + + // Look for "coding:" or "coding=" + if let Some(pos) = line_str + .find("coding:") + .or_else(|| line_str.find("coding=")) + { + let after_coding = &line_str[pos + 7..]; // Skip "coding:" or "coding=" + + // Extract encoding name (alphanumeric, dash, underscore until whitespace or special char) + let encoding_name: String = after_coding + .trim_start() + .chars() + .take_while(|c| c.is_alphanumeric() || *c == '-' || *c == '_') + .collect(); + + if !encoding_name.is_empty() { + // Try to get the encoding + return Encoding::for_label(encoding_name.as_bytes()); + } + } + + None +} diff --git a/rust/src/graph_building.rs b/rust/src/graph_building.rs index bcd4d631..3551540b 100644 --- a/rust/src/graph_building.rs +++ b/rust/src/graph_building.rs @@ -1,6 +1,8 @@ use pyo3::prelude::*; +use pyo3::types::PyModule; use std::path::PathBuf; +use crate::errors::GrimpError; use crate::graph::GraphWrapper; use crate::graph::builder::{PackageSpec, build_graph}; @@ -21,19 +23,40 @@ impl PyPackageSpec { } #[pyfunction] -#[pyo3(signature = (package, include_external_packages=false, exclude_type_checking_imports=false, cache_dir=None))] +#[pyo3(signature = (packages, include_external_packages=false, exclude_type_checking_imports=false, cache_dir=None))] pub fn build_graph_rust( - package: &PyPackageSpec, + py: Python, + packages: Vec, include_external_packages: bool, exclude_type_checking_imports: bool, cache_dir: Option, ) -> PyResult { let cache_path = cache_dir.map(PathBuf::from); - let graph = build_graph( - &package.inner, + + // Extract the inner PackageSpec from each PyPackageSpec + let package_specs: Vec = packages.iter().map(|p| p.inner.clone()).collect(); + + let graph_result = build_graph( + &package_specs, include_external_packages, exclude_type_checking_imports, cache_path.as_ref(), - )?; - Ok(GraphWrapper::from_graph(graph)) + ); + + match graph_result { + Ok(graph) => Ok(GraphWrapper::from_graph(graph)), + Err(GrimpError::ParseError { + module_filename, + line_number, + text, + .. + }) => { + // Import the Python SourceSyntaxError from grimp.exceptions + let exceptions_module = PyModule::import(py, "grimp.exceptions")?; + let source_syntax_error = exceptions_module.getattr("SourceSyntaxError")?; + let exception = source_syntax_error.call1((module_filename, line_number, text))?; + Err(PyErr::from_value(exception)) + } + Err(e) => Err(e.into()), + } } diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index 215c4768..dff2771d 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -83,13 +83,16 @@ def build_graph_rust( file_system: AbstractFileSystem = settings.FILE_SYSTEM package_finder: AbstractPackageFinder = settings.PACKAGE_FINDER - # Determine the package directory - package_directory = package_finder.determine_package_directory( - package_name=package_name, file_system=file_system - ) + # Collect all package names + all_package_names = [package_name] + list(additional_package_names) - # Create package spec and build the graph - package_spec = rust.PackageSpec(package_name, package_directory) + # Create package specs for all packages + package_specs = [] + for package_name in all_package_names: + package_directory = package_finder.determine_package_directory( + package_name=package_name, file_system=file_system + ) + package_specs.append(rust.PackageSpec(package_name, package_directory)) # Handle cache_dir cache_dir_arg: str | None = None @@ -100,7 +103,7 @@ def build_graph_rust( # Build the graph rust_graph = rust.build_graph_rust( - package_spec, + package_specs, include_external_packages=include_external_packages, exclude_type_checking_imports=exclude_type_checking_imports, cache_dir=cache_dir_arg, From 0accb849c004886aaf3f1ba2a0e375ccbb303f1f Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Tue, 11 Nov 2025 02:51:54 +0100 Subject: [PATCH 16/19] TEMP Run all functional tests for build_graph_rust --- tests/functional/conftest.py | 12 ++++ tests/functional/test_build_and_use_graph.py | 64 ++++++++++++------- ...build_and_use_graph_with_multiple_roots.py | 9 +-- .../test_build_graph_on_real_packages.py | 6 +- tests/functional/test_encoding_handling.py | 11 ++-- tests/functional/test_error_handling.py | 6 +- tests/functional/test_namespace_packages.py | 36 +++++++---- 7 files changed, 94 insertions(+), 50 deletions(-) create mode 100644 tests/functional/conftest.py diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py new file mode 100644 index 00000000..5092c386 --- /dev/null +++ b/tests/functional/conftest.py @@ -0,0 +1,12 @@ +import pytest + +import grimp + + +@pytest.fixture(params=["python", "rust"], ids=["python", "rust"]) +def build_graph(request): + """Fixture that provides both Python and Rust graph building implementations.""" + if request.param == "python": + return grimp.build_graph + else: + return grimp.build_graph_rust diff --git a/tests/functional/test_build_and_use_graph.py b/tests/functional/test_build_and_use_graph.py index ed151598..7d42cbd4 100644 --- a/tests/functional/test_build_and_use_graph.py +++ b/tests/functional/test_build_and_use_graph.py @@ -1,7 +1,5 @@ -from grimp import build_graph import pytest - """ For ease of reference, these are the imports of all the files: @@ -30,7 +28,9 @@ # --------- -def test_modules(): +def test_modules( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) assert graph.modules == { @@ -53,7 +53,9 @@ def test_modules(): } -def test_add_module(): +def test_add_module( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) number_of_modules = len(graph.modules) @@ -62,7 +64,9 @@ def test_add_module(): assert number_of_modules + 1 == len(graph.modules) -def test_remove_module(): +def test_remove_module( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) number_of_modules = len(graph.modules) @@ -71,7 +75,9 @@ def test_remove_module(): assert number_of_modules - 1 == len(graph.modules) -def test_add_and_remove_import(): +def test_add_and_remove_import( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) a = "testpackage.one.delta.blue" b = "testpackage.two.alpha" @@ -91,7 +97,9 @@ def test_add_and_remove_import(): # ----------- -def test_find_children(): +def test_find_children( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) assert graph.find_children("testpackage.one") == { @@ -102,7 +110,9 @@ def test_find_children(): } -def test_find_descendants(): +def test_find_descendants( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) assert graph.find_descendants("testpackage.one") == { @@ -118,7 +128,9 @@ def test_find_descendants(): # -------------- -def test_find_modules_directly_imported_by(): +def test_find_modules_directly_imported_by( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) result = graph.find_modules_directly_imported_by("testpackage.utils") @@ -126,7 +138,9 @@ def test_find_modules_directly_imported_by(): assert {"testpackage.one", "testpackage.two.alpha"} == result -def test_find_modules_that_directly_import(): +def test_find_modules_that_directly_import( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) result = graph.find_modules_that_directly_import("testpackage.one.alpha") @@ -141,7 +155,9 @@ def test_find_modules_that_directly_import(): } == result -def test_direct_import_exists(): +def test_direct_import_exists( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) assert False is graph.direct_import_exists( @@ -155,7 +171,9 @@ def test_direct_import_exists(): ) -def test_get_import_details(): +def test_get_import_details( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) assert [ @@ -173,7 +191,7 @@ def test_get_import_details(): class TestPathExists: - def test_as_packages_false(self): + def test_as_packages_false(self, build_graph): graph = build_graph("testpackage", cache_dir=None) assert not graph.chain_exists( @@ -182,7 +200,7 @@ def test_as_packages_false(self): assert graph.chain_exists(imported="testpackage.one.alpha", importer="testpackage.utils") - def test_as_packages_true(self): + def test_as_packages_true(self, build_graph): graph = build_graph("testpackage", cache_dir=None) assert graph.chain_exists( @@ -194,7 +212,9 @@ def test_as_packages_true(self): ) -def test_find_shortest_chain(): +def test_find_shortest_chain( + build_graph, +): graph = build_graph("testpackage", cache_dir=None) assert ( @@ -246,7 +266,7 @@ def test_find_shortest_chain(): ), ], ) -def test_find_shortest_chains(as_packages: bool | None, expected_result: set[tuple]): +def test_find_shortest_chains(build_graph, as_packages: bool | None, expected_result: set[tuple]): importer = "testpackage.three" imported = "testpackage.one.alpha" @@ -263,7 +283,7 @@ def test_find_shortest_chains(as_packages: bool | None, expected_result: set[tup class TestFindDownstreamModules: - def test_as_package_false(self): + def test_as_package_false(self, build_graph): graph = build_graph("testpackage", cache_dir=None) result = graph.find_downstream_modules("testpackage.one.alpha") @@ -281,7 +301,7 @@ def test_as_package_false(self): "testpackage.three.gamma", } == result - def test_as_package_true(self): + def test_as_package_true(self, build_graph): graph = build_graph("testpackage", cache_dir=None) result = graph.find_downstream_modules("testpackage.one", as_package=True) @@ -299,7 +319,7 @@ def test_as_package_true(self): class TestFindUpstreamModules: - def test_as_package_false(self): + def test_as_package_false(self, build_graph): graph = build_graph("testpackage", cache_dir=None) assert graph.find_upstream_modules("testpackage.one.alpha") == set() @@ -310,7 +330,7 @@ def test_as_package_false(self): "testpackage.one.alpha", } - def test_as_package_true(self): + def test_as_package_true(self, build_graph): graph = build_graph("testpackage", cache_dir=None) assert graph.find_upstream_modules("testpackage.two", as_package=True) == { @@ -321,13 +341,13 @@ def test_as_package_true(self): class TestExcludeTypeCheckingImports: - def test_exclude_false(self): + def test_exclude_false(self, build_graph): graph = build_graph("testpackage", cache_dir=None, exclude_type_checking_imports=False) assert True is graph.direct_import_exists( importer="testpackage.one.beta", imported="testpackage.two.alpha" ) - def test_exclude_true(self): + def test_exclude_true(self, build_graph): graph = build_graph("testpackage", cache_dir=None, exclude_type_checking_imports=True) assert False is graph.direct_import_exists( importer="testpackage.one.beta", imported="testpackage.two.alpha" diff --git a/tests/functional/test_build_and_use_graph_with_multiple_roots.py b/tests/functional/test_build_and_use_graph_with_multiple_roots.py index 4c0efdcf..48fe6839 100644 --- a/tests/functional/test_build_and_use_graph_with_multiple_roots.py +++ b/tests/functional/test_build_and_use_graph_with_multiple_roots.py @@ -1,5 +1,4 @@ import pytest # type: ignore -from grimp import build_graph """ For ease of reference, these are the imports of all the files: @@ -24,7 +23,9 @@ @pytest.mark.parametrize("root_packages", PACKAGES_IN_DIFFERENT_ORDERS) class TestBuildGraph: - def test_graph_has_correct_modules_regardless_of_package_order(self, root_packages): + def test_graph_has_correct_modules_regardless_of_package_order( + self, build_graph, root_packages + ): graph = build_graph(*root_packages, cache_dir=None) assert graph.modules == { @@ -38,7 +39,7 @@ def test_graph_has_correct_modules_regardless_of_package_order(self, root_packag "rootpackagegreen.two", } - def test_stores_import_within_package(self, root_packages): + def test_stores_import_within_package(self, build_graph, root_packages): graph = build_graph(*root_packages, cache_dir=None) assert [ @@ -52,7 +53,7 @@ def test_stores_import_within_package(self, root_packages): importer="rootpackageblue.two", imported="rootpackageblue.one.alpha" ) - def test_stores_import_between_root_packages(self, root_packages): + def test_stores_import_between_root_packages(self, build_graph, root_packages): graph = build_graph(*root_packages, cache_dir=None) assert [ diff --git a/tests/functional/test_build_graph_on_real_packages.py b/tests/functional/test_build_graph_on_real_packages.py index 0f92f77d..431ee6f6 100644 --- a/tests/functional/test_build_graph_on_real_packages.py +++ b/tests/functional/test_build_graph_on_real_packages.py @@ -24,7 +24,9 @@ def test_build_graph_on_real_package(package_name, snapshot): ) def test_nominate_cycle_breakers_django(package_name, snapshot): graph = grimp.build_graph("django") - cycle_breakers = graph.nominate_cycle_breakers(package_name) - assert cycle_breakers == snapshot + + graph_from_rust = grimp.build_graph("django") + cycle_breakers_from_rust = graph_from_rust.nominate_cycle_breakers(package_name) + assert cycle_breakers_from_rust == cycle_breakers diff --git a/tests/functional/test_encoding_handling.py b/tests/functional/test_encoding_handling.py index 21a8f5a2..72761a9f 100644 --- a/tests/functional/test_encoding_handling.py +++ b/tests/functional/test_encoding_handling.py @@ -1,11 +1,8 @@ -import grimp - - -def test_build_graph_of_non_ascii_source(): +def test_build_graph_of_non_ascii_source(build_graph): """ Tests we can cope with non ascii Python source files. """ - graph = grimp.build_graph("encodingpackage", cache_dir=None) + graph = build_graph("encodingpackage", cache_dir=None) result = graph.get_import_details( importer="encodingpackage.importer", imported="encodingpackage.imported" @@ -21,11 +18,11 @@ def test_build_graph_of_non_ascii_source(): ] == result -def test_build_graph_of_non_utf8_source(): +def test_build_graph_of_non_utf8_source(build_graph): """ Tests we can cope with non UTF-8 Python source files. """ - graph = grimp.build_graph("encodingpackage", cache_dir=None) + graph = build_graph("encodingpackage", cache_dir=None) result = graph.get_import_details( importer="encodingpackage.shift_jis_importer", imported="encodingpackage.imported" diff --git a/tests/functional/test_error_handling.py b/tests/functional/test_error_handling.py index 45635986..014fb7a4 100644 --- a/tests/functional/test_error_handling.py +++ b/tests/functional/test_error_handling.py @@ -3,10 +3,10 @@ import pytest # type: ignore -from grimp import build_graph, exceptions +from grimp import exceptions -def test_syntax_error_includes_module(): +def test_syntax_error_includes_module(build_graph): dirname = os.path.dirname(__file__) filename = os.path.abspath( os.path.join(dirname, "..", "assets", "syntaxerrorpackage", "foo", "one.py") @@ -21,7 +21,7 @@ def test_syntax_error_includes_module(): assert expected_exception == excinfo.value -def test_missing_root_init_file(): +def test_missing_root_init_file(build_graph): with pytest.raises( exceptions.NamespacePackageEncountered, match=re.escape( diff --git a/tests/functional/test_namespace_packages.py b/tests/functional/test_namespace_packages.py index 0d0d8b83..2f97e0ea 100644 --- a/tests/functional/test_namespace_packages.py +++ b/tests/functional/test_namespace_packages.py @@ -1,6 +1,6 @@ import pytest # type: ignore -from grimp import build_graph, exceptions +from grimp import exceptions """ For ease of reference, these are the imports of all the files: @@ -18,7 +18,9 @@ """ -def test_build_graph_for_namespace(): +def test_build_graph_for_namespace( + build_graph, +): with pytest.raises(exceptions.NamespacePackageEncountered): build_graph("mynamespace", cache_dir=None) @@ -40,13 +42,15 @@ def test_build_graph_for_namespace(): ), ), ) -def test_modules_for_namespace_child(package, expected_modules): +def test_modules_for_namespace_child(build_graph, package, expected_modules): graph = build_graph(package, cache_dir=None) assert graph.modules == expected_modules -def test_modules_for_multiple_namespace_children(): +def test_modules_for_multiple_namespace_children( + build_graph, +): graph = build_graph("mynamespace.green", "mynamespace.blue", cache_dir=None) assert graph.modules == GREEN_MODULES | BLUE_MODULES @@ -73,7 +77,7 @@ def test_modules_for_multiple_namespace_children(): ), ) def test_external_packages_handling( - packages, expected_internal_modules, expected_external_modules + build_graph, packages, expected_internal_modules, expected_external_modules ): graph = build_graph(*packages, include_external_packages=True, cache_dir=None) @@ -81,7 +85,9 @@ def test_external_packages_handling( assert all(graph.is_module_squashed(m) for m in expected_external_modules) -def test_import_within_namespace_child(): +def test_import_within_namespace_child( + build_graph, +): graph = build_graph("mynamespace.blue", cache_dir=None) assert graph.direct_import_exists( @@ -89,7 +95,9 @@ def test_import_within_namespace_child(): ) -def test_import_between_namespace_children(): +def test_import_between_namespace_children( + build_graph, +): graph = build_graph("mynamespace.blue", "mynamespace.green", cache_dir=None) assert graph.direct_import_exists( @@ -104,7 +112,7 @@ def test_import_between_namespace_children(): "package_name", ("nestednamespace", "nestednamespace.foo", "nestednamespace.foo.alpha"), ) -def test_build_graph_for_nested_namespace(package_name): +def test_build_graph_for_nested_namespace(build_graph, package_name): with pytest.raises(exceptions.NamespacePackageEncountered): build_graph(package_name, cache_dir=None) @@ -134,13 +142,15 @@ def test_build_graph_for_nested_namespace(package_name): ), ), ) -def test_modules_for_nested_namespace_child(package, expected_modules): +def test_modules_for_nested_namespace_child(build_graph, package, expected_modules): graph = build_graph(package, cache_dir=None) assert graph.modules == expected_modules -def test_import_within_nested_namespace_child(): +def test_import_within_nested_namespace_child( + build_graph, +): graph = build_graph( "nestednamespace.foo.alpha.blue", cache_dir=None, @@ -152,7 +162,9 @@ def test_import_within_nested_namespace_child(): ) -def test_import_between_nested_namespace_children(): +def test_import_between_nested_namespace_children( + build_graph, +): graph = build_graph( "nestednamespace.foo.alpha.blue", "nestednamespace.foo.alpha.green", @@ -204,7 +216,7 @@ def test_import_between_nested_namespace_children(): ), ) def test_external_packages_handling_for_nested_namespaces( - packages, expected_internal_modules, expected_external_modules + build_graph, packages, expected_internal_modules, expected_external_modules ): graph = build_graph(*packages, include_external_packages=True, cache_dir=None) From 8a0b7a8a85bd2a99a3613acc6b8d3c68116abc6c Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Tue, 11 Nov 2025 09:49:20 +0100 Subject: [PATCH 17/19] Cache improvements --- rust/src/graph/builder/cache.rs | 125 ------------------ rust/src/graph/builder/imports_cache.rs | 169 ++++++++++++++++++++++++ rust/src/graph/builder/mod.rs | 47 +++---- src/grimp/application/usecases.py | 11 +- 4 files changed, 191 insertions(+), 161 deletions(-) delete mode 100644 rust/src/graph/builder/cache.rs create mode 100644 rust/src/graph/builder/imports_cache.rs diff --git a/rust/src/graph/builder/cache.rs b/rust/src/graph/builder/cache.rs deleted file mode 100644 index dfb92ce1..00000000 --- a/rust/src/graph/builder/cache.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::collections::HashMap; -use std::fs; -use std::io::{Read as _, Write as _}; -use std::path::{Path, PathBuf}; - -use bincode::{Decode, Encode}; - -use crate::errors::{GrimpError, GrimpResult}; -use crate::import_parsing::ImportedObject; - -#[derive(Debug, Clone, Encode, Decode)] -struct CachedImports { - mtime_secs: i64, - imported_objects: Vec, -} - -impl CachedImports { - fn new(mtime_secs: i64, imported_objects: Vec) -> Self { - Self { - mtime_secs, - imported_objects, - } - } - - fn mtime_secs(&self) -> i64 { - self.mtime_secs - } - - fn imported_objects(&self) -> &[ImportedObject] { - &self.imported_objects - } -} - -/// Cache for storing parsed import information indexed by module name. -#[derive(Debug, Clone)] -pub struct ImportsCache { - cache_dir: PathBuf, - package_name: String, - cache: HashMap, -} - -impl ImportsCache { - /// Get cached imports for a module if they exist and the mtime matches. - pub fn get_imports(&self, module_name: &str, mtime_secs: i64) -> Option> { - let cached = self.cache.get(module_name)?; - if cached.mtime_secs() == mtime_secs { - Some(cached.imported_objects().to_vec()) - } else { - None - } - } - - /// Store parsed imports for a module. - pub fn set_imports( - &mut self, - module_name: String, - mtime_secs: i64, - imports: Vec, - ) { - self.cache - .insert(module_name, CachedImports::new(mtime_secs, imports)); - } - - /// Save the cache to disk. - pub fn save(&self) -> GrimpResult<()> { - fs::create_dir_all(&self.cache_dir).map_err(|e| GrimpError::CacheWriteError { - path: self.cache_dir.display().to_string(), - error: e.to_string(), - })?; - - let cache_file = self - .cache_dir - .join(format!("{}.imports.bincode", self.package_name)); - - let encoded = - bincode::encode_to_vec(&self.cache, bincode::config::standard()).map_err(|e| { - GrimpError::CacheWriteError { - path: cache_file.display().to_string(), - error: e.to_string(), - } - })?; - - let mut file = fs::File::create(&cache_file).map_err(|e| GrimpError::CacheWriteError { - path: cache_file.display().to_string(), - error: e.to_string(), - })?; - - file.write_all(&encoded) - .map_err(|e| GrimpError::CacheWriteError { - path: cache_file.display().to_string(), - error: e.to_string(), - })?; - - Ok(()) - } -} - -/// Load cache from disk. -pub fn load_cache(cache_dir: &Path, package_name: &str) -> ImportsCache { - let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); - - let cache_map = if let Ok(mut file) = fs::File::open(&cache_file) { - let mut buffer = Vec::new(); - if file.read_to_end(&mut buffer).is_ok() { - if let Ok(decoded) = bincode::decode_from_slice::, _>( - &buffer, - bincode::config::standard(), - ) { - decoded.0 - } else { - HashMap::new() - } - } else { - HashMap::new() - } - } else { - HashMap::new() - }; - - ImportsCache { - cache: cache_map, - cache_dir: cache_dir.to_path_buf(), - package_name: package_name.to_string(), - } -} diff --git a/rust/src/graph/builder/imports_cache.rs b/rust/src/graph/builder/imports_cache.rs new file mode 100644 index 00000000..fea03fd0 --- /dev/null +++ b/rust/src/graph/builder/imports_cache.rs @@ -0,0 +1,169 @@ +use std::collections::HashMap; +use std::fs; +use std::io::{Read as _, Write as _}; +use std::path::{Path, PathBuf}; + +use bincode::{Decode, Encode}; + +use crate::errors::{GrimpError, GrimpResult}; +use crate::import_parsing::ImportedObject; + +/// Cache for storing parsed import information. +#[derive(Debug, Clone)] +pub struct ImportsCache { + cache_dir: PathBuf, + // Map of package name to package imports. + cache: HashMap>, +} + +impl ImportsCache { + /// Get cached imports for a module if they exist and the mtime matches. + pub fn get_imports( + &self, + package_name: &str, + module_name: &str, + mtime_secs: i64, + ) -> Option> { + let package_cache = self.cache.get(package_name)?; + let cached = package_cache.get(module_name)?; + if cached.mtime_secs() == mtime_secs { + Some(cached.imported_objects().to_vec()) + } else { + None + } + } + + /// Store parsed imports for a module. + pub fn set_imports( + &mut self, + package_name: String, + module_name: String, + mtime_secs: i64, + imports: Vec, + ) { + self.cache + .entry(package_name) + .or_default() + .insert(module_name, CachedImports::new(mtime_secs, imports)); + } + + /// Save cache to disk. + pub fn save(&self) -> GrimpResult<()> { + fs::create_dir_all(&self.cache_dir).map_err(|e| GrimpError::CacheWriteError { + path: self.cache_dir.display().to_string(), + error: e.to_string(), + })?; + + // Write marker files if they don't exist + self.write_marker_files_if_missing()?; + + for (package_name, package_cache) in &self.cache { + let cache_file = self + .cache_dir + .join(format!("{}.imports.bincode", package_name)); + + let encoded = bincode::encode_to_vec(package_cache, bincode::config::standard()) + .map_err(|e| GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + })?; + + let mut file = + fs::File::create(&cache_file).map_err(|e| GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + })?; + + file.write_all(&encoded) + .map_err(|e| GrimpError::CacheWriteError { + path: cache_file.display().to_string(), + error: e.to_string(), + })?; + } + + Ok(()) + } + + /// Write marker files (.gitignore and CACHEDIR.TAG) if they don't already exist. + fn write_marker_files_if_missing(&self) -> GrimpResult<()> { + let marker_files = [ + (".gitignore", "# Automatically created by Grimp.\n*"), + ( + "CACHEDIR.TAG", + "Signature: 8a477f597d28d172789f06886806bc55\n\ + # This file is a cache directory tag automatically created by Grimp.\n\ + # For information about cache directory tags see https://bford.info/cachedir/", + ), + ]; + + for (filename, contents) in &marker_files { + let full_path = self.cache_dir.join(filename); + if !full_path.exists() { + fs::write(&full_path, contents).map_err(|e| GrimpError::CacheWriteError { + path: full_path.display().to_string(), + error: e.to_string(), + })?; + } + } + + Ok(()) + } + + /// Load cache from disk. + pub fn load(cache_dir: &Path, package_names: &[String]) -> Self { + let mut cache = HashMap::new(); + + for package_name in package_names { + let cache_file = cache_dir.join(format!("{}.imports.bincode", package_name)); + + let package_cache = if let Ok(mut file) = fs::File::open(&cache_file) { + let mut buffer = Vec::new(); + if file.read_to_end(&mut buffer).is_ok() { + if let Ok(decoded) = bincode::decode_from_slice::< + HashMap, + _, + >(&buffer, bincode::config::standard()) + { + decoded.0 + } else { + HashMap::new() + } + } else { + HashMap::new() + } + } else { + HashMap::new() + }; + + cache.insert(package_name.clone(), package_cache); + } + + ImportsCache { + cache, + cache_dir: cache_dir.to_path_buf(), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +struct CachedImports { + mtime_secs: i64, + imported_objects: Vec, +} + +impl CachedImports { + fn new(mtime_secs: i64, imported_objects: Vec) -> Self { + Self { + mtime_secs, + imported_objects, + } + } + + fn mtime_secs(&self) -> i64 { + self.mtime_secs + } + + fn imported_objects(&self) -> &[ImportedObject] { + &self.imported_objects + } +} diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 8d3c34ff..95424044 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -12,8 +12,8 @@ use crate::errors::{GrimpError, GrimpResult}; use crate::graph::Graph; use crate::import_parsing::{ImportedObject, parse_imports_from_code}; -mod cache; -use cache::{ImportsCache, load_cache}; +mod imports_cache; +use imports_cache::ImportsCache; mod utils; use utils::{ @@ -108,12 +108,10 @@ fn discover_and_parse_modules( let thread_counts = calculate_thread_counts(); thread::scope(|scope| { - // Load caches for all packages if available - store in HashMap by package name - let caches: Option> = cache_dir.map(|dir| { - packages - .iter() - .map(|pkg| (pkg.name.clone(), load_cache(dir, &pkg.name))) - .collect() + // Load cache for all packages if available + let mut cache: Option = cache_dir.map(|dir| { + let package_names: Vec = packages.iter().map(|p| p.name.clone()).collect(); + ImportsCache::load(dir, &package_names) }); // Create channels for the pipeline @@ -135,15 +133,10 @@ fn discover_and_parse_modules( let receiver = found_module_receiver.clone(); let sender = parsed_module_sender.clone(); let error_sender = error_sender.clone(); - let caches = caches.clone(); + let cache = cache.clone(); scope.spawn(move || { while let Ok(module) = receiver.recv() { - // Look up the cache for this module's package - let cache = caches - .as_ref() - .and_then(|map| map.get(&module.package_name)); - - match parse_module_imports(&module, cache) { + match parse_module_imports(&module, cache.as_ref()) { Ok(parsed) => { let _ = sender.send(parsed); } @@ -169,20 +162,17 @@ fn discover_and_parse_modules( return Err(error); } - // Update and save all caches - if let Some(mut caches) = caches { + // Update and save cache + if let Some(cache) = &mut cache { for parsed in &parsed_modules { - if let Some(cache) = caches.get_mut(&parsed.module.package_name) { - cache.set_imports( - parsed.module.name.clone(), - parsed.module.mtime_secs, - parsed.imported_objects.clone(), - ); - } - } - for cache in caches.values_mut() { - cache.save()?; + cache.set_imports( + parsed.module.package_name.clone(), + parsed.module.name.clone(), + parsed.module.mtime_secs, + parsed.imported_objects.clone(), + ); } + cache.save()?; } Ok(parsed_modules) @@ -284,7 +274,8 @@ fn parse_module_imports( ) -> GrimpResult { // Check if we have a cached version with matching mtime if let Some(cache) = cache - && let Some(imported_objects) = cache.get_imports(&module.name, module.mtime_secs) + && let Some(imported_objects) = + cache.get_imports(&module.package_name, &module.name, module.mtime_secs) { // Cache hit - use cached imports return Ok(ParsedModule { diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index dff2771d..70396b00 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -80,17 +80,12 @@ def build_graph_rust( """ from grimp import _rustgrimp as rust # type: ignore[attr-defined] - file_system: AbstractFileSystem = settings.FILE_SYSTEM - package_finder: AbstractPackageFinder = settings.PACKAGE_FINDER - - # Collect all package names - all_package_names = [package_name] + list(additional_package_names) - # Create package specs for all packages + all_package_names = [package_name] + list(additional_package_names) package_specs = [] for package_name in all_package_names: - package_directory = package_finder.determine_package_directory( - package_name=package_name, file_system=file_system + package_directory = settings.PACKAGE_FINDER.determine_package_directory( + package_name=package_name, file_system=settings.FILE_SYSTEM ) package_specs.append(rust.PackageSpec(package_name, package_directory)) From 76a732ce0f3318a5c171859355f79c5a0afc363e Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Tue, 11 Nov 2025 10:10:34 +0100 Subject: [PATCH 18/19] Move python file reading utils to own file --- rust/src/graph/builder/mod.rs | 5 +- rust/src/graph/builder/read_python_file.rs | 103 +++++++++++++++++++++ rust/src/graph/builder/utils.rs | 103 --------------------- 3 files changed, 107 insertions(+), 104 deletions(-) create mode 100644 rust/src/graph/builder/read_python_file.rs diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 95424044..326310e6 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -18,9 +18,12 @@ use imports_cache::ImportsCache; mod utils; use utils::{ ResolvedImport, distill_external_module, is_internal, is_package, path_to_module_name, - read_python_file, resolve_internal_module, resolve_relative_import, + resolve_internal_module, resolve_relative_import, }; +mod read_python_file; +use read_python_file::read_python_file; + #[derive(Debug, Clone, new)] pub struct PackageSpec { name: String, diff --git a/rust/src/graph/builder/read_python_file.rs b/rust/src/graph/builder/read_python_file.rs new file mode 100644 index 00000000..0d99d76d --- /dev/null +++ b/rust/src/graph/builder/read_python_file.rs @@ -0,0 +1,103 @@ +use std::{fs, io::Read, path::Path}; + +use encoding_rs::Encoding; + +use crate::errors::{GrimpError, GrimpResult}; + +/// Read a Python source file with proper encoding detection. +/// +/// Python PEP 263 specifies that encoding can be declared in the first or second line +/// in the format: `# coding: ` or `# -*- coding: -*-` +/// +/// This function: +/// 1. Reads the file as bytes +/// 2. Checks the first two lines for an encoding declaration +/// 3. Decodes the file using the detected encoding (or UTF-8 as default) +pub fn read_python_file(path: &Path) -> GrimpResult { + // Read file as bytes + let mut file = fs::File::open(path).map_err(|e| GrimpError::FileReadError { + path: path.display().to_string(), + error: e.to_string(), + })?; + + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes) + .map_err(|e| GrimpError::FileReadError { + path: path.display().to_string(), + error: e.to_string(), + })?; + + // Detect encoding from first two lines + let encoding = detect_python_encoding(&bytes); + + // Decode using detected encoding + let (decoded, _encoding_used, had_errors) = encoding.decode(&bytes); + + if had_errors { + return Err(GrimpError::FileReadError { + path: path.display().to_string(), + error: format!("Failed to decode file with encoding {}", encoding.name()), + }); + } + + Ok(decoded.into_owned()) +} + +/// Detect Python source file encoding from the first two lines. +/// +/// Looks for patterns like: +/// - `# coding: ` +/// - `# -*- coding: -*-` +/// - `# coding=` +fn detect_python_encoding(bytes: &[u8]) -> &'static Encoding { + // Read first two lines as ASCII (encoding declarations must be ASCII-compatible) + let mut line_count = 0; + let mut line_start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + if byte == b'\n' { + line_count += 1; + if line_count <= 2 { + // Check this line for encoding declaration + let line = &bytes[line_start..i]; + if let Some(encoding) = extract_encoding_from_line(line) { + return encoding; + } + line_start = i + 1; + } else { + break; + } + } + } + + // Default to UTF-8 + encoding_rs::UTF_8 +} + +/// Extract encoding from a single line if it contains an encoding declaration. +fn extract_encoding_from_line(line: &[u8]) -> Option<&'static Encoding> { + // Convert line to string (should be ASCII for encoding declarations) + let line_str = std::str::from_utf8(line).ok()?; + + // Look for "coding:" or "coding=" + if let Some(pos) = line_str + .find("coding:") + .or_else(|| line_str.find("coding=")) + { + let after_coding = &line_str[pos + 7..]; // Skip "coding:" or "coding=" + + // Extract encoding name (alphanumeric, dash, underscore until whitespace or special char) + let encoding_name: String = after_coding + .trim_start() + .chars() + .take_while(|c| c.is_alphanumeric() || *c == '-' || *c == '_') + .collect(); + + if !encoding_name.is_empty() { + // Try to get the encoding + return Encoding::for_label(encoding_name.as_bytes()); + } + } + + None +} diff --git a/rust/src/graph/builder/utils.rs b/rust/src/graph/builder/utils.rs index 9805f8bc..b25003fd 100644 --- a/rust/src/graph/builder/utils.rs +++ b/rust/src/graph/builder/utils.rs @@ -1,11 +1,6 @@ use std::collections::HashSet; -use std::fs; -use std::io::Read; use std::path::Path; -use encoding_rs::Encoding; - -use crate::errors::{GrimpError, GrimpResult}; use crate::graph::builder::PackageSpec; #[derive(Debug, Clone, Hash, Eq, PartialEq)] @@ -173,101 +168,3 @@ pub fn distill_external_module( Some(module_root.to_string()) } } - -/// Read a Python source file with proper encoding detection. -/// -/// Python PEP 263 specifies that encoding can be declared in the first or second line -/// in the format: `# coding: ` or `# -*- coding: -*-` -/// -/// This function: -/// 1. Reads the file as bytes -/// 2. Checks the first two lines for an encoding declaration -/// 3. Decodes the file using the detected encoding (or UTF-8 as default) -pub fn read_python_file(path: &Path) -> GrimpResult { - // Read file as bytes - let mut file = fs::File::open(path).map_err(|e| GrimpError::FileReadError { - path: path.display().to_string(), - error: e.to_string(), - })?; - - let mut bytes = Vec::new(); - file.read_to_end(&mut bytes) - .map_err(|e| GrimpError::FileReadError { - path: path.display().to_string(), - error: e.to_string(), - })?; - - // Detect encoding from first two lines - let encoding = detect_python_encoding(&bytes); - - // Decode using detected encoding - let (decoded, _encoding_used, had_errors) = encoding.decode(&bytes); - - if had_errors { - return Err(GrimpError::FileReadError { - path: path.display().to_string(), - error: format!("Failed to decode file with encoding {}", encoding.name()), - }); - } - - Ok(decoded.into_owned()) -} - -/// Detect Python source file encoding from the first two lines. -/// -/// Looks for patterns like: -/// - `# coding: ` -/// - `# -*- coding: -*-` -/// - `# coding=` -fn detect_python_encoding(bytes: &[u8]) -> &'static Encoding { - // Read first two lines as ASCII (encoding declarations must be ASCII-compatible) - let mut line_count = 0; - let mut line_start = 0; - - for (i, &byte) in bytes.iter().enumerate() { - if byte == b'\n' { - line_count += 1; - if line_count <= 2 { - // Check this line for encoding declaration - let line = &bytes[line_start..i]; - if let Some(encoding) = extract_encoding_from_line(line) { - return encoding; - } - line_start = i + 1; - } else { - break; - } - } - } - - // Default to UTF-8 - encoding_rs::UTF_8 -} - -/// Extract encoding from a single line if it contains an encoding declaration. -fn extract_encoding_from_line(line: &[u8]) -> Option<&'static Encoding> { - // Convert line to string (should be ASCII for encoding declarations) - let line_str = std::str::from_utf8(line).ok()?; - - // Look for "coding:" or "coding=" - if let Some(pos) = line_str - .find("coding:") - .or_else(|| line_str.find("coding=")) - { - let after_coding = &line_str[pos + 7..]; // Skip "coding:" or "coding=" - - // Extract encoding name (alphanumeric, dash, underscore until whitespace or special char) - let encoding_name: String = after_coding - .trim_start() - .chars() - .take_while(|c| c.is_alphanumeric() || *c == '-' || *c == '_') - .collect(); - - if !encoding_name.is_empty() { - // Try to get the encoding - return Encoding::for_label(encoding_name.as_bytes()); - } - } - - None -} From efa9e4bdd58c1b565901eb94de0f0ab84b86ab53 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Tue, 11 Nov 2025 17:34:10 +0100 Subject: [PATCH 19/19] Demo rust unit test with temp file system --- rust/Cargo.lock | 200 +++++++++++++++- rust/Cargo.toml | 3 + rust/src/graph/builder/mod.rs | 59 +++++ rust/src/lib.rs | 3 + rust/src/test_utils.rs | 431 ++++++++++++++++++++++++++++++++++ 5 files changed, 694 insertions(+), 2 deletions(-) create mode 100644 rust/src/test_utils.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 104ffb89..51f21a15 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -12,10 +12,12 @@ dependencies = [ "crossbeam", "derive-new", "encoding_rs", + "filetime", "getset", "ignore", "indexmap 2.11.0", "itertools 0.14.0", + "map-macro", "parameterized", "pyo3", "rayon", @@ -30,6 +32,7 @@ dependencies = [ "slotmap", "string-interner", "tap", + "tempfile", "thiserror", "unindent", ] @@ -206,6 +209,34 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "filetime" +version = "0.2.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.60.2", +] + [[package]] name = "foldhash" version = "0.1.5" @@ -232,6 +263,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "getset" version = "0.1.6" @@ -362,12 +405,35 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libredox" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +dependencies = [ + "bitflags", + "libc", + "redox_syscall", +] + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "log" version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +[[package]] +name = "map-macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb950a42259642e5a3483115aca87eebed2a64886993463af9c9739c205b8d3a" + [[package]] name = "memchr" version = "2.7.5" @@ -564,6 +630,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" version = "0.8.5" @@ -591,7 +663,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.16", ] [[package]] @@ -614,6 +686,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.11.2" @@ -715,6 +796,19 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "ryu" version = "1.0.20" @@ -829,6 +923,19 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "2.0.16" @@ -959,13 +1066,22 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "winapi-util" version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -974,6 +1090,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -983,6 +1108,77 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "zerocopy" version = "0.8.26" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 96cb5a57..fd1097cc 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -43,3 +43,6 @@ default = ["extension-module"] [dev-dependencies] parameterized = "2.0.0" serde_json = "1.0.143" +tempfile = "3.14.0" +filetime = "0.2.26" +map-macro = "0.3.0" diff --git a/rust/src/graph/builder/mod.rs b/rust/src/graph/builder/mod.rs index 326310e6..c9c2298f 100644 --- a/rust/src/graph/builder/mod.rs +++ b/rust/src/graph/builder/mod.rs @@ -26,7 +26,9 @@ use read_python_file::read_python_file; #[derive(Debug, Clone, new)] pub struct PackageSpec { + #[new(into)] name: String, + #[new(into)] directory: PathBuf, } @@ -409,3 +411,60 @@ fn calculate_thread_counts() -> ThreadCounts { module_parsing: num_threads, } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{DEFAULT_MTIME, TempFileSystemBuilder}; + use map_macro::hash_set; + + #[test] + fn test_discover_modules_happy_path() { + const SOME_MTIME: i64 = 12340000; + + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + not-a-python-file.txt + .hidden + foo/ + __init__.py + one.py + two/ + __init__.py + green.py + blue.py + "#, + ) + .with_file_mtime_map([("mypackage/foo/one.py", SOME_MTIME)]) + .build() + .unwrap(); + + let result = discover_and_parse_modules( + &[PackageSpec::new("mypackage", temp_fs.join("mypackage"))], + None, + ) + .unwrap(); + + assert_eq!(result.len(), 6); + assert_eq!( + result + .iter() + .map(|p| ( + p.module.name.as_str(), + p.module.is_package, + p.module.mtime_secs + )) + .collect::>(), + hash_set! { + ("mypackage", true, DEFAULT_MTIME), + ("mypackage.foo", true, DEFAULT_MTIME), + ("mypackage.foo.one", false, SOME_MTIME), + ("mypackage.foo.two", true, DEFAULT_MTIME), + ("mypackage.foo.two.green", false, DEFAULT_MTIME), + ("mypackage.foo.two.blue", false, DEFAULT_MTIME), + } + ); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 35aa1736..a1a2f422 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -9,6 +9,9 @@ mod import_scanning; pub mod module_expressions; mod module_finding; +#[cfg(test)] +pub mod test_utils; + use pyo3::prelude::*; #[pymodule] diff --git a/rust/src/test_utils.rs b/rust/src/test_utils.rs new file mode 100644 index 00000000..93ac23d7 --- /dev/null +++ b/rust/src/test_utils.rs @@ -0,0 +1,431 @@ +use filetime::{FileTime, set_file_mtime}; +use std::collections::HashMap; +use std::fs; +use std::io::Write; +use std::path::{Path, PathBuf}; +use tempfile::TempDir; + +/// Default mtime for files (in seconds since Unix epoch) +pub const DEFAULT_MTIME: i64 = 10000; + +/// A builder for creating temporary directory structures for testing. +/// +/// # Example +/// +/// ``` +/// use _rustgrimp::test_utils::TempFileSystemBuilder; +/// +/// let temp_fs = TempFileSystemBuilder::new(r#" +/// mypackage/ +/// __init__.py +/// foo/ +/// __init__.py +/// one.py +/// "#) +/// .with_file_content_map([ +/// ("mypackage/foo/one.py", "from . import two"), +/// ]) +/// .with_file_mtime_map([ +/// ("mypackage/foo/one.py", 12340000), +/// ]) +/// .build() +/// .unwrap(); +/// +/// let package_dir = temp_fs.join("mypackage"); +/// ``` +pub struct TempFileSystemBuilder { + contents: String, + content_map: HashMap, + mtime_overrides: HashMap, +} + +impl TempFileSystemBuilder { + /// Create a new builder with the directory structure. + /// + /// The string should be formatted with directories ending in `/` and files without. + /// Indentation determines the hierarchy. + /// + /// # Example + /// + /// ``` + /// let builder = TempFileSystemBuilder::new(r#" + /// mypackage/ + /// __init__.py + /// foo/ + /// __init__.py + /// one.py + /// "#); + /// ``` + pub fn new(contents: &str) -> Self { + Self { + contents: contents.to_string(), + content_map: HashMap::new(), + mtime_overrides: HashMap::new(), + } + } + + /// Set the content for a specific file (relative path). + /// + /// # Arguments + /// + /// * `path` - Relative path to the file from the temp directory root + /// * `content` - The text content of the file + /// + /// # Example + /// + /// ``` + /// let builder = TempFileSystemBuilder::new("...") + /// .with_file_content("mypackage/foo/one.py", "from . import two"); + /// ``` + pub fn with_file_content(mut self, path: &str, content: &str) -> Self { + self.content_map + .insert(path.to_string(), content.to_string()); + self + } + + /// Set the content for multiple files at once. + /// + /// # Arguments + /// + /// * `content_map` - An iterator of (path, content) pairs + /// + /// # Example + /// + /// ``` + /// let builder = TempFileSystemBuilder::new("...") + /// .with_file_content_map([ + /// ("mypackage/foo/one.py", "from . import two"), + /// ("mypackage/foo/two.py", "x = 1"), + /// ]); + /// ``` + pub fn with_file_content_map( + mut self, + content_map: impl IntoIterator, impl Into)>, + ) -> Self { + self.content_map + .extend(content_map.into_iter().map(|(k, v)| (k.into(), v.into()))); + self + } + + /// Set a custom modification time for a specific file (relative path). + /// + /// # Arguments + /// + /// * `path` - Relative path to the file from the temp directory root + /// * `mtime` - Modification time in seconds since Unix epoch + /// + /// # Example + /// + /// ``` + /// let builder = TempFileSystemBuilder::new("...") + /// .with_mtime("mypackage/foo/one.py", 12340000); + /// ``` + pub fn with_file_mtime(mut self, path: &str, mtime: i64) -> Self { + self.mtime_overrides.insert(path.to_string(), mtime); + self + } + + /// Set custom modification times for multiple files at once. + /// + /// # Arguments + /// + /// * `mtime_map` - An iterator of (path, mtime) pairs + /// + /// # Example + /// + /// ``` + /// let builder = TempFileSystemBuilder::new("...") + /// .with_file_mtime_map([ + /// ("mypackage/foo/one.py", 12340000), + /// ("mypackage/foo/two.py", 12350000), + /// ]); + /// ``` + pub fn with_file_mtime_map( + mut self, + mtime_map: impl IntoIterator, i64)>, + ) -> Self { + self.mtime_overrides + .extend(mtime_map.into_iter().map(|(k, v)| (k.into(), v))); + self + } + + /// Build the temporary file system + pub fn build(self) -> std::io::Result { + let temp_dir = TempDir::new()?; + + // Create the directory structure + Self::create_structure(temp_dir.path(), &self.contents)?; + + // Write file contents from content_map + for (relative_path, content) in &self.content_map { + let full_path = temp_dir.path().join(relative_path); + + // Ensure parent directory exists + if let Some(parent) = full_path.parent() { + fs::create_dir_all(parent)?; + } + + let mut file = fs::File::create(&full_path)?; + file.write_all(content.as_bytes())?; + + // Set default mtime for files created via content_map + let default_filetime = FileTime::from_unix_time(DEFAULT_MTIME, 0); + set_file_mtime(&full_path, default_filetime)?; + } + + // Apply mtime overrides (must come after content_map writes) + for (relative_path, mtime) in &self.mtime_overrides { + let full_path = temp_dir.path().join(relative_path); + if full_path.exists() { + let filetime = FileTime::from_unix_time(*mtime, 0); + set_file_mtime(&full_path, filetime)?; + } + } + + Ok(TempFileSystem { temp_dir }) + } + + fn create_structure(base_path: &Path, contents: &str) -> std::io::Result<()> { + let lines: Vec<&str> = contents + .lines() + .map(|l| l.trim_end()) + .filter(|l| !l.is_empty()) + .collect(); + + if lines.is_empty() { + return Ok(()); + } + + // Calculate minimum indentation to dedent + let min_indent = lines + .iter() + .filter(|l| !l.trim().is_empty()) + .map(|l| l.len() - l.trim_start().len()) + .min() + .unwrap_or(0); + + // Parse the structure + let mut stack: Vec<(usize, PathBuf)> = vec![(0, base_path.to_path_buf())]; + + for line in lines { + let trimmed = line.trim_start(); + if trimmed.is_empty() { + continue; + } + + let indent = line.len() - trimmed.len() - min_indent; + let indent_level = indent / 4; // Assume 4 spaces per level + + // Pop stack until we find the parent + while stack.len() > indent_level + 1 { + stack.pop(); + } + + let parent_path = &stack.last().unwrap().1; + let name = trimmed.trim(); + + if name.ends_with('/') { + // Directory + let dir_name = name.trim_end_matches('/'); + let dir_path = parent_path.join(dir_name); + fs::create_dir_all(&dir_path)?; + stack.push((indent_level, dir_path)); + } else { + // File + let file_path = parent_path.join(name); + let mut file = fs::File::create(&file_path)?; + // Create empty file + file.write_all(b"")?; + + // Set default mtime + let default_filetime = FileTime::from_unix_time(DEFAULT_MTIME, 0); + set_file_mtime(&file_path, default_filetime)?; + } + } + + Ok(()) + } +} + +/// A temporary directory structure for testing. +pub struct TempFileSystem { + temp_dir: TempDir, +} + +impl TempFileSystem { + /// Get the root path of the temporary directory + pub fn path(&self) -> &Path { + self.temp_dir.path() + } + + /// Join a path component to the root path + /// + /// # Example + /// + /// ``` + /// let temp_fs = TempFileSystemBuilder::new("...").build().unwrap(); + /// let package_dir = temp_fs.join("mypackage"); + /// ``` + pub fn join(&self, path: impl AsRef) -> PathBuf { + self.temp_dir.path().join(path) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_simple_structure() { + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + foo/ + __init__.py + one.py + "#, + ) + .build() + .unwrap(); + + let base = temp_fs.path(); + assert!(base.join("mypackage").is_dir()); + assert!(base.join("mypackage/__init__.py").is_file()); + assert!(base.join("mypackage/foo").is_dir()); + assert!(base.join("mypackage/foo/__init__.py").is_file()); + assert!(base.join("mypackage/foo/one.py").is_file()); + } + + #[test] + fn test_create_with_file_content() { + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + foo/ + __init__.py + one.py + "#, + ) + .with_file_content("mypackage/foo/one.py", "from . import two") + .build() + .unwrap(); + + let one_py = temp_fs.path().join("mypackage/foo/one.py"); + let content = fs::read_to_string(&one_py).unwrap(); + assert_eq!(content, "from . import two"); + } + + #[test] + fn test_create_with_content_map() { + let mut content_map = HashMap::new(); + content_map.insert("mypackage/foo/one.py".to_string(), "import sys".to_string()); + content_map.insert("mypackage/foo/two.py".to_string(), "x = 1".to_string()); + + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + foo/ + __init__.py + one.py + two.py + "#, + ) + .with_file_content_map(content_map) + .build() + .unwrap(); + + let one_py = temp_fs.path().join("mypackage/foo/one.py"); + let content = fs::read_to_string(&one_py).unwrap(); + assert_eq!(content, "import sys"); + + let two_py = temp_fs.path().join("mypackage/foo/two.py"); + let content = fs::read_to_string(&two_py).unwrap(); + assert_eq!(content, "x = 1"); + } + + #[test] + fn test_create_with_custom_mtimes() { + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + foo/ + __init__.py + one.py + "#, + ) + .with_file_mtime("mypackage/foo/one.py", 12340000) + .build() + .unwrap(); + + let one_py = temp_fs.path().join("mypackage/foo/one.py"); + let metadata = fs::metadata(&one_py).unwrap(); + let mtime = FileTime::from_last_modification_time(&metadata); + assert_eq!(mtime.unix_seconds(), 12340000); + + let init_py = temp_fs.path().join("mypackage/__init__.py"); + let metadata = fs::metadata(&init_py).unwrap(); + let mtime = FileTime::from_last_modification_time(&metadata); + assert_eq!(mtime.unix_seconds(), DEFAULT_MTIME); + } + + #[test] + fn test_nested_directories() { + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + foo/ + __init__.py + two/ + __init__.py + green.py + blue.py + "#, + ) + .build() + .unwrap(); + + let base = temp_fs.path(); + assert!(base.join("mypackage/foo/two").is_dir()); + assert!(base.join("mypackage/foo/two/green.py").is_file()); + assert!(base.join("mypackage/foo/two/blue.py").is_file()); + } + + #[test] + fn test_builder_chaining() { + let temp_fs = TempFileSystemBuilder::new( + r#" + mypackage/ + __init__.py + foo/ + __init__.py + one.py + two.py + "#, + ) + .with_file_mtime("mypackage/foo/one.py", 11111111) + .with_file_mtime("mypackage/foo/two.py", 22222222) + .with_file_content("mypackage/foo/one.py", "# one") + .with_file_content("mypackage/foo/two.py", "# two") + .build() + .unwrap(); + + let one_py = temp_fs.path().join("mypackage/foo/one.py"); + let metadata = fs::metadata(&one_py).unwrap(); + let mtime = FileTime::from_last_modification_time(&metadata); + assert_eq!(mtime.unix_seconds(), 11111111); + let content = fs::read_to_string(&one_py).unwrap(); + assert_eq!(content, "# one"); + + let two_py = temp_fs.path().join("mypackage/foo/two.py"); + let metadata = fs::metadata(&two_py).unwrap(); + let mtime = FileTime::from_last_modification_time(&metadata); + assert_eq!(mtime.unix_seconds(), 22222222); + let content = fs::read_to_string(&two_py).unwrap(); + assert_eq!(content, "# two"); + } +}