diff --git a/vulnerabilities/importers/rust.py b/vulnerabilities/importers/rust.py index c61907a82..deb8bf24a 100644 --- a/vulnerabilities/importers/rust.py +++ b/vulnerabilities/importers/rust.py @@ -8,7 +8,9 @@ # import asyncio +import logging from itertools import chain +from typing import Iterable from typing import List from typing import Optional from typing import Set @@ -27,8 +29,19 @@ from vulnerabilities.package_managers import CratesVersionAPI from vulnerabilities.utils import nearest_patched_package +logger = logging.getLogger(__name__) + class RustImporter(Importer): + def __init__(self, purl=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.purl = purl + if self.purl: + if self.purl.type != "cargo": + print( + f"Warning: PURL type {self.purl.type} is not 'cargo', may not match any advisories" + ) + def __enter__(self): super(RustImporter, self).__enter__() @@ -49,12 +62,19 @@ def set_api(self, packages): asyncio.run(self.crates_api.load_api(packages)) def updated_advisories(self) -> Set[AdvisoryData]: - return self._load_advisories(self._updated_files.union(self._added_files)) + if not self.purl: + return self._load_advisories(self._updated_files.union(self._added_files)) + + return self._load_advisories_for_package(self.purl.name) def _load_advisories(self, files) -> Set[AdvisoryData]: # per @tarcieri It will always be named RUSTSEC-0000-0000.md # https://github.com/nexB/vulnerablecode/pull/281/files#r528899864 files = [f for f in files if not f.endswith("-0000.md")] # skip temporary files + if self.purl: + files = [f for f in files if f"crates/{self.purl.name}/" in f] + if not files: + return [] packages = self.collect_packages(files) self.set_api(packages) @@ -64,6 +84,12 @@ def _load_advisories(self, files) -> Set[AdvisoryData]: for path in batch: advisory = self._load_advisory(path) if advisory: + if ( + self.purl + and self.purl.version + and not self._advisory_affects_version(advisory) + ): + continue advisories.append(advisory) yield advisories @@ -133,6 +159,42 @@ def _load_advisory(self, path: str) -> Optional[AdvisoryData]: references=references, ) + def _advisory_affects_version(self, advisory: AdvisoryData) -> bool: + if not self.purl.version: + return True + + version = SemverVersion(self.purl.version) + for affected_package in advisory.affected_packages: + if affected_package.package.name == self.purl.name: + if ( + affected_package.affected_version_range + and version in affected_package.affected_version_range + ): + return True + + return False + + def _load_advisories_for_package(self, package_name) -> Iterable[AdvisoryData]: + files = [ + f + for f in self._added_files.union(self._updated_files) + if f"crates/{package_name}/" in f and f.endswith(".md") and not f.endswith("-0000.md") + ] + + if not files: + logger.info(f"No advisories found for {package_name} in Rust advisory database") + return + + self.set_api([package_name]) + + for path in files: + advisory = self._load_advisory(path) + if advisory: + # If version is specified in PURL, check if it's in the affected versions + if self.purl.version and not self._advisory_affects_version(advisory): + continue + yield advisory + def categorize_versions( all_versions: Set[str], diff --git a/vulnerabilities/tests/test_rust.py b/vulnerabilities/tests/test_rust.py index 58b7c4302..97b54d8cc 100644 --- a/vulnerabilities/tests/test_rust.py +++ b/vulnerabilities/tests/test_rust.py @@ -10,8 +10,10 @@ import os from unittest import TestCase +import pytest from packageurl import PackageURL from univers.version_range import VersionRange +from univers.versions import SemverVersion from vulnerabilities.importer import AdvisoryData from vulnerabilities.importer import Reference @@ -183,3 +185,46 @@ def test_load_toml_from_md(self): } assert loaded_data == expected_data + + +@pytest.fixture +def rust_importer_with_mock(monkeypatch): + class DummyVCSResponse: + repo_dirs = [os.path.join(TEST_DATA, "..", "test_data", "rust")] + + importer = RustImporter() + importer._crates_api = MOCKED_CRATES_API_VERSIONS + importer.vcs_response = DummyVCSResponse() + return importer + + +def test_rust_importer_package_first_affecting(rust_importer_with_mock): + purl = PackageURL(type="cargo", name="byte_struct") + importer = rust_importer_with_mock + importer.purl = purl + advisories = list(importer._load_advisories_for_package("byte_struct")) + assert len(advisories) == 1 + assert any(ap.package.name == "byte_struct" for ap in advisories[0].affected_packages) + + +def test_rust_importer_package_first_version_affecting(rust_importer_with_mock): + purl = PackageURL(type="cargo", name="byte_struct", version="0.6.0") + importer = rust_importer_with_mock + importer.purl = purl + advisories = list(importer._load_advisories_for_package("byte_struct")) + + assert len(advisories) == 1 + found = False + for ap in advisories[0].affected_packages: + if ap.package.name == "byte_struct": + if ap.affected_version_range and SemverVersion("0.6.0") in ap.affected_version_range: + found = True + assert found + + +def test_rust_importer_package_first_not_found(rust_importer_with_mock): + purl = PackageURL(type="cargo", name="nonexistent") + importer = rust_importer_with_mock + importer.purl = purl + advisories = list(importer._load_advisories_for_package(purl.name)) + assert advisories == []