Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions rust/src/caching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,34 @@ use crate::errors::{GrimpError, GrimpResult};
use crate::filesystem::get_file_system_boxed;
use crate::import_scanning::{DirectImport, imports_by_module_to_py};
use crate::module_finding::Module;
use pyo3::types::PyDict;
use pyo3::{Bound, PyAny, PyResult, Python, pyfunction};
use pyo3::types::PyAnyMethods;
use pyo3::types::{PyDict, PySet};
use pyo3::types::{PyDictMethods, PySetMethods};
use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python, pyfunction};
use std::collections::{HashMap, HashSet};

/// Writes the cache file containing all the imports for a given package.
/// Args:
/// - filename: str
/// - imports_by_module: dict[Module, Set[DirectImport]]
/// - file_system: The file system interface to use. (A BasicFileSystem.)
#[pyfunction]
pub fn write_cache_data_map_file<'py>(
filename: &str,
imports_by_module: Bound<'py, PyDict>,
file_system: Bound<'py, PyAny>,
) -> PyResult<()> {
let mut file_system_boxed = get_file_system_boxed(&file_system)?;

let ImportsByModule(imports_by_module_rust) = imports_by_module.extract()?;

let file_contents = serialize_imports_by_module(&imports_by_module_rust);

file_system_boxed.write(filename, &file_contents)?;

Ok(())
}

/// Reads the cache file containing all the imports for a given package.
/// Args:
/// - filename: str
Expand All @@ -26,6 +50,52 @@ pub fn read_cache_data_map_file<'py>(
Ok(imports_by_module_to_py(py, imports_by_module))
}

/// A newtype wrapper for HashMap<Module, HashSet<DirectImport>> that implements FromPyObject.
pub struct ImportsByModule(pub HashMap<Module, HashSet<DirectImport>>);

impl<'py> FromPyObject<'py> for ImportsByModule {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
let py_dict = ob.downcast::<PyDict>()?;
let mut imports_by_module_rust = HashMap::new();

for (py_key, py_value) in py_dict.iter() {
let module: Module = py_key.extract()?;
let py_set = py_value.downcast::<PySet>()?;
let mut hashset: HashSet<DirectImport> = HashSet::new();
for element in py_set.iter() {
let direct_import: DirectImport = element.extract()?;
hashset.insert(direct_import);
}
imports_by_module_rust.insert(module, hashset);
}

Ok(ImportsByModule(imports_by_module_rust))
}
}

fn serialize_imports_by_module(
imports_by_module: &HashMap<Module, HashSet<DirectImport>>,
) -> String {
let raw_map: HashMap<&str, Vec<(&str, usize, &str)>> = imports_by_module
.iter()
.map(|(module, imports)| {
let imports_vec: Vec<(&str, usize, &str)> = imports
.iter()
.map(|import| {
(
import.imported.as_str(),
import.line_number,
import.line_contents.as_str(),
)
})
.collect();
(module.name.as_str(), imports_vec)
})
.collect();

serde_json::to_string(&raw_map).expect("Failed to serialize to JSON")
}

pub fn parse_json_to_map(
json_str: &str,
filename: &str,
Expand Down
54 changes: 42 additions & 12 deletions rust/src/filesystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use regex::Regex;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::fs;
use std::fs::File;
use std::io::prelude::*;
use std::path::{Path, PathBuf};
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock, Mutex};
use unindent::unindent;

static ENCODING_RE: LazyLock<Regex> =
Expand All @@ -22,17 +24,19 @@ pub trait FileSystem: Send + Sync {
fn exists(&self, file_name: &str) -> bool;

fn read(&self, file_name: &str) -> PyResult<String>;

fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()>;
}

#[derive(Clone)]
#[pyclass]
pub struct RealBasicFileSystem {}
struct RealBasicFileSystem {}

// Implements a BasicFileSystem (defined in grimp.application.ports.filesystem.BasicFileSystem)
// that actually reads files.
#[pyclass(name = "RealBasicFileSystem")]
pub struct PyRealBasicFileSystem {
pub inner: RealBasicFileSystem,
inner: RealBasicFileSystem,
}

impl FileSystem for RealBasicFileSystem {
Expand Down Expand Up @@ -129,6 +133,16 @@ impl FileSystem for RealBasicFileSystem {
})
}
}

fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> {
let file_path: PathBuf = file_name.into();
if let Some(patent_dir) = file_path.parent() {
fs::create_dir_all(patent_dir)?;
}
File::create(file_path)?
.write_all(contents.as_bytes())
.map_err(Into::into)
}
}

#[pymethods]
Expand Down Expand Up @@ -161,19 +175,23 @@ impl PyRealBasicFileSystem {
fn read(&self, file_name: &str) -> PyResult<String> {
self.inner.read(file_name)
}

fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> {
self.inner.write(file_name, contents)
}
}

type FileSystemContents = HashMap<String, String>;

#[derive(Clone)]
pub struct FakeBasicFileSystem {
contents: Box<FileSystemContents>,
struct FakeBasicFileSystem {
contents: Arc<Mutex<FileSystemContents>>,
}

// Implements BasicFileSystem (defined in grimp.application.ports.filesystem.BasicFileSystem).
#[pyclass(name = "FakeBasicFileSystem")]
pub struct PyFakeBasicFileSystem {
pub inner: FakeBasicFileSystem,
inner: FakeBasicFileSystem,
}

impl FakeBasicFileSystem {
Expand All @@ -190,7 +208,7 @@ impl FakeBasicFileSystem {
parsed_contents.extend(unindented_map);
};
Ok(FakeBasicFileSystem {
contents: Box::new(parsed_contents),
contents: Arc::new(Mutex::new(parsed_contents)),
})
}
}
Expand Down Expand Up @@ -232,17 +250,25 @@ impl FileSystem for FakeBasicFileSystem {

/// Checks if a file or directory exists within the file system.
fn exists(&self, file_name: &str) -> bool {
self.contents.contains_key(file_name)
self.contents.lock().unwrap().contains_key(file_name)
}

fn read(&self, file_name: &str) -> PyResult<String> {
match self.contents.get(file_name) {
Some(file_name) => Ok(file_name.clone()),
let contents = self.contents.lock().unwrap();
match contents.get(file_name) {
Some(file_contents) => Ok(file_contents.clone()),
None => Err(PyFileNotFoundError::new_err(format!(
"No such file: {file_name}"
))),
}
}

#[allow(unused_variables)]
fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> {
let mut contents_mut = self.contents.lock().unwrap();
contents_mut.insert(file_name.to_string(), contents.to_string());
Ok(())
}
}

#[pymethods]
Expand Down Expand Up @@ -278,6 +304,10 @@ impl PyFakeBasicFileSystem {
self.inner.read(file_name)
}

fn write(&mut self, file_name: &str, contents: &str) -> PyResult<()> {
self.inner.write(file_name, contents)
}

// Temporary workaround method for Python tests.
fn convert_to_basic(&self) -> PyResult<Self> {
Ok(PyFakeBasicFileSystem {
Expand All @@ -289,7 +319,7 @@ impl PyFakeBasicFileSystem {
/// Parses an indented string representing a file system structure
/// into a HashMap where keys are full file paths.
/// See tests.adaptors.filesystem.FakeFileSystem for the API.
pub fn parse_indented_file_system_string(file_system_string: &str) -> HashMap<String, String> {
fn parse_indented_file_system_string(file_system_string: &str) -> HashMap<String, String> {
let mut file_paths_map: HashMap<String, String> = HashMap::new();
let mut path_stack: Vec<String> = Vec::new(); // Stores current directory path components
let mut first_line = true; // Flag to handle the very first path component
Expand Down Expand Up @@ -381,7 +411,6 @@ pub fn get_file_system_boxed<'py>(
file_system: &Bound<'py, PyAny>,
) -> PyResult<Box<dyn FileSystem + Send + Sync>> {
let file_system_boxed: Box<dyn FileSystem + Send + Sync>;

if let Ok(py_real) = file_system.extract::<PyRef<PyRealBasicFileSystem>>() {
file_system_boxed = Box::new(py_real.inner.clone());
} else if let Ok(py_fake) = file_system.extract::<PyRef<PyFakeBasicFileSystem>>() {
Expand All @@ -391,5 +420,6 @@ pub fn get_file_system_boxed<'py>(
"file_system must be an instance of RealBasicFileSystem or FakeBasicFileSystem",
));
}

Ok(file_system_boxed)
}
24 changes: 20 additions & 4 deletions rust/src/import_scanning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,23 @@ pub struct DirectImport {
pub line_contents: String,
}

pub fn py_found_packages_to_rust(py_found_packages: &Bound<'_, PyAny>) -> HashSet<FoundPackage> {
impl<'py> FromPyObject<'py> for DirectImport {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
let importer: String = ob.getattr("importer")?.getattr("name")?.extract()?;
let imported: String = ob.getattr("imported")?.getattr("name")?.extract()?;
let line_number: usize = ob.getattr("line_number")?.extract()?;
let line_contents: String = ob.getattr("line_contents")?.extract()?;

Ok(DirectImport {
importer,
imported,
line_number,
line_contents,
})
}
}

fn py_found_packages_to_rust(py_found_packages: &Bound<'_, PyAny>) -> HashSet<FoundPackage> {
let py_set = py_found_packages
.downcast::<PySet>()
.expect("Expected py_found_packages to be a Python set.");
Expand All @@ -36,7 +52,7 @@ pub fn py_found_packages_to_rust(py_found_packages: &Bound<'_, PyAny>) -> HashSe
rust_found_packages
}

pub fn get_modules_from_found_packages(found_packages: &HashSet<FoundPackage>) -> HashSet<Module> {
fn get_modules_from_found_packages(found_packages: &HashSet<FoundPackage>) -> HashSet<Module> {
let mut modules = HashSet::new();
for package in found_packages {
for module_file in &package.module_files {
Expand All @@ -57,7 +73,7 @@ fn module_is_descendant(module_name: &str, potential_ancestor: &str) -> bool {
/// Statically analyses the given module and returns a set of Modules that
/// it imports.
#[allow(clippy::borrowed_box)]
pub fn scan_for_imports_no_py(
fn scan_for_imports_no_py(
file_system: &Box<dyn FileSystem + Send + Sync>,
found_packages: &HashSet<FoundPackage>,
include_external_packages: bool,
Expand Down Expand Up @@ -153,7 +169,7 @@ fn scan_for_imports_no_py_single_module(
Ok(imports)
}

pub fn to_py_direct_imports<'a>(
fn to_py_direct_imports<'a>(
py: Python<'a>,
rust_imports: &HashSet<DirectImport>,
) -> Bound<'a, PySet> {
Expand Down
3 changes: 3 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ mod _rustgrimp {
#[pymodule_export]
use crate::caching::read_cache_data_map_file;

#[pymodule_export]
use crate::caching::write_cache_data_map_file;

#[pymodule_export]
use crate::graph::GraphWrapper;

Expand Down
29 changes: 9 additions & 20 deletions src/grimp/adaptors/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import Optional

from grimp.application.ports.filesystem import AbstractFileSystem
from grimp.application.ports.filesystem import BasicFileSystem
from grimp.application.ports.modulefinder import FoundPackage, ModuleFile
from grimp.domain.valueobjects import DirectImport, Module

Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(self, *args, namer: type[CacheFileNamer], **kwargs) -> None:
@classmethod
def setup(
cls,
file_system: AbstractFileSystem,
file_system: BasicFileSystem,
found_packages: set[FoundPackage],
include_external_packages: bool,
exclude_type_checking_imports: bool = False,
Expand Down Expand Up @@ -122,22 +122,6 @@ def write(
) -> None:
self._write_marker_files_if_not_already_there()
# Write data file.
primitives_map: PrimitiveFormat = {}
for found_package in self.found_packages:
primitives_map_for_found_package: PrimitiveFormat = {
module_file.module.name: [
(
direct_import.imported.name,
direct_import.line_number,
direct_import.line_contents,
)
for direct_import in imports_by_module[module_file.module]
]
for module_file in found_package.module_files
}
primitives_map.update(primitives_map_for_found_package)

serialized = json.dumps(primitives_map)
data_cache_filename = self.file_system.join(
self.cache_dir,
self._namer.make_data_file_name(
Expand All @@ -146,7 +130,12 @@ def write(
exclude_type_checking_imports=self.exclude_type_checking_imports,
),
)
self.file_system.write(data_cache_filename, serialized)
rust.write_cache_data_map_file(
filename=data_cache_filename,
imports_by_module=imports_by_module,
file_system=self.file_system,
)

logger.info(f"Wrote data cache file {data_cache_filename}.")

# Write meta files.
Expand Down Expand Up @@ -202,7 +191,7 @@ def _read_data_map_file(self) -> dict[Module, set[DirectImport]]:
)
try:
imports_by_module = rust.read_cache_data_map_file(
data_cache_filename, self.file_system.convert_to_basic()
data_cache_filename, self.file_system
)
except FileNotFoundError:
logger.info(f"No cache file: {data_cache_filename}.")
Expand Down
6 changes: 3 additions & 3 deletions src/grimp/application/ports/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from grimp.application.ports.modulefinder import FoundPackage, ModuleFile
from grimp.domain.valueobjects import DirectImport, Module

from .filesystem import AbstractFileSystem
from .filesystem import BasicFileSystem


class CacheMiss(Exception):
Expand All @@ -11,7 +11,7 @@ class CacheMiss(Exception):
class Cache:
def __init__(
self,
file_system: AbstractFileSystem,
file_system: BasicFileSystem,
include_external_packages: bool,
exclude_type_checking_imports: bool,
found_packages: set[FoundPackage],
Expand All @@ -29,7 +29,7 @@ def __init__(
@classmethod
def setup(
cls,
file_system: AbstractFileSystem,
file_system: BasicFileSystem,
found_packages: set[FoundPackage],
*,
include_external_packages: bool,
Expand Down
Loading