Skip to content

Commit 0d5292d

Browse files
perf(tidy3d): FXC-3721 Speed up test suite
1 parent 2ce5542 commit 0d5292d

File tree

15 files changed

+433
-74
lines changed

15 files changed

+433
-74
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,6 @@ htmlcov/
135135
.idea
136136
.vscode
137137

138-
# cProfile output
138+
# profile outputs
139139
*.prof
140+
pytest_profile.txt

docs/development/usage.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ There are a range of handy development functions that you might want to use to s
6767
* - Running ``pytest`` commands inside the ``poetry`` environment.
6868
- Make sure you have already installed ``tidy3d`` in ``poetry`` and you are in the root directory.
6969
- ``poetry run pytest``
70+
* - Analyze slow ``pytest`` runs with durations / cProfile / debug subset helpers.
71+
- Use ``--debug`` to run only the first N collected tests or ``--profile`` to capture call stacks.
72+
- ``python scripts/profile_pytest.py [options]``
7073
* - Run ``coverage`` testing from the ``poetry`` environment.
7174
-
7275
- ``poetry run coverage run -m pytest``
@@ -84,4 +87,3 @@ There are a range of handy development functions that you might want to use to s
8487
- ``poetry run tidy3d develop replace-in-files``
8588

8689

87-

poetry.lock

Lines changed: 18 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pytest-timeout = { version = "*", optional = true }
6565
pytest-xdist = "^3.6.1"
6666
pytest-cov = "^6.0.0"
6767
pytest-env = "^1.1.5"
68+
pytest-order = { version = "^1.2.1", optional = true }
6869
tox = { version = "*", optional = true }
6970
diff-cover = { version = "*", optional = true }
7071
zizmor = { version = "*", optional = true }
@@ -154,6 +155,7 @@ dev = [
154155
'pytest-xdist',
155156
'pytest-env',
156157
'pytest-cov',
158+
'pytest-order',
157159
'rtree',
158160
'ruff',
159161
'sax',

scripts/profile_pytest.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
#!/usr/bin/env python3
2+
"""Helper utilities for profiling ``pytest`` runs inside the Poetry env.
3+
4+
This script can:
5+
* run the full test suite (default) while surfacing the slowest tests via ``--durations``;
6+
* run in "debug" mode to execute only the first N collected tests; and
7+
* wrap ``pytest`` in ``cProfile`` to identify the most expensive function calls.
8+
9+
Examples::
10+
11+
python scripts/profile_pytest.py # full suite with slowest 25 tests listed
12+
python scripts/profile_pytest.py --debug --debug-limit 10
13+
python scripts/profile_pytest.py --profile --profile-output results.prof
14+
python scripts/profile_pytest.py -t tests/test_components/test_scene.py \
15+
--pytest-args "-k basic"
16+
17+
Forward any additional `pytest` CLI flags via ``--pytest-args"...`` and provide
18+
explicit test targets with ``-t/--tests`` (defaults to the entire ``tests`` dir).
19+
"""
20+
21+
from __future__ import annotations
22+
23+
import argparse
24+
import re
25+
import shlex
26+
import shutil
27+
import subprocess
28+
import sys
29+
from collections import defaultdict
30+
from collections.abc import Iterable
31+
from pathlib import Path
32+
33+
try:
34+
import pstats
35+
except ImportError as exc: # pragma: no cover - stdlib module should exist
36+
raise SystemExit("pstats from the standard library is required") from exc
37+
38+
DURATION_LINE_RE = re.compile(r"^\s*(?P<secs>\d+(?:\.\d+)?)s\s+\w+\s+(?P<nodeid>\S+)\s*$")
39+
40+
41+
def parse_args() -> argparse.Namespace:
42+
parser = argparse.ArgumentParser(
43+
description="Profile pytest executions launched via Poetry.",
44+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
45+
)
46+
parser.add_argument(
47+
"--debug",
48+
action="store_true",
49+
help="Run only a subset of collected tests (see --debug-limit).",
50+
)
51+
parser.add_argument(
52+
"--list-limit",
53+
type=int,
54+
default=30,
55+
help="How many entries to show in aggregated duration summaries (set 0 for all).",
56+
)
57+
parser.add_argument(
58+
"--debug-limit",
59+
type=int,
60+
default=25,
61+
help="Number of test node ids to execute when --debug is enabled.",
62+
)
63+
parser.add_argument(
64+
"--durations",
65+
type=int,
66+
default=0,
67+
help="Pass-through value for pytest's --durations flag (use 0 for all tests).",
68+
)
69+
parser.add_argument(
70+
"--profile",
71+
action="store_true",
72+
help="Wrap pytest in cProfile and display the heaviest call sites afterward.",
73+
)
74+
parser.add_argument(
75+
"--profile-output",
76+
default="results.prof",
77+
help="Where to write the binary cProfile stats (used when --profile is set).",
78+
)
79+
parser.add_argument(
80+
"--profile-top",
81+
type=int,
82+
default=30,
83+
help="How many rows of aggregated profile data to print.",
84+
)
85+
parser.add_argument(
86+
"--profile-sort",
87+
choices=["cumulative", "tottime", "calls", "time"],
88+
default="cumulative",
89+
help="Sort order for the profile summary table.",
90+
)
91+
parser.add_argument(
92+
"-t",
93+
"--tests",
94+
action="append",
95+
dest="tests",
96+
metavar="PATH_OR_NODE",
97+
help="Explicit pytest targets. Repeatable.",
98+
)
99+
parser.add_argument(
100+
"--pytest-args",
101+
default="",
102+
help="Extra pytest CLI args as a quoted string (e.g. '--maxfail=1 -k smoke').",
103+
)
104+
return parser.parse_args()
105+
106+
107+
def ensure_poetry_available() -> None:
108+
if shutil.which("poetry") is None:
109+
raise SystemExit("'poetry' command not found in PATH.")
110+
111+
112+
def build_pytest_base(profile: bool, profile_output: Path) -> list[str]:
113+
base_cmd = ["poetry", "run"]
114+
if profile:
115+
base_cmd += [
116+
"python",
117+
"-m",
118+
"cProfile",
119+
"-o",
120+
str(profile_output.resolve()),
121+
"-m",
122+
"pytest",
123+
]
124+
else:
125+
base_cmd.append("pytest")
126+
return base_cmd
127+
128+
129+
def collect_node_ids(extra_args: Iterable[str], tests: Iterable[str]) -> list[str]:
130+
cmd = ["poetry", "run", "pytest", "--collect-only", "-q"]
131+
cmd.extend(extra_args)
132+
cmd.extend(tests)
133+
print(f"Collecting tests via: {' '.join(shlex.quote(part) for part in cmd)}")
134+
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
135+
sys.stdout.write(result.stdout)
136+
sys.stderr.write(result.stderr)
137+
if result.returncode != 0:
138+
raise SystemExit(result.returncode)
139+
140+
node_ids: list[str] = []
141+
for line in result.stdout.splitlines():
142+
stripped = line.strip()
143+
if not stripped or stripped.startswith(("<", "collected ")):
144+
continue
145+
node_ids.append(stripped)
146+
if not node_ids:
147+
raise SystemExit("No tests collected; check your --tests / --pytest-args filters.")
148+
return node_ids
149+
150+
151+
def summarize_profile(stats_path: Path, sort: str, top: int) -> None:
152+
if not stats_path.exists():
153+
print(f"Profile file {stats_path} not found; skipping summary.")
154+
return
155+
stats = pstats.Stats(str(stats_path))
156+
stats.sort_stats(sort)
157+
print("\nTop profiled call sites (via cProfile):")
158+
stats.print_stats(top)
159+
160+
161+
def extract_durations_from_output(output: str) -> list[tuple[float, str]]:
162+
"""Parse pytest --durations lines from stdout."""
163+
164+
durations: list[tuple[float, str]] = []
165+
for line in output.splitlines():
166+
match = DURATION_LINE_RE.match(line)
167+
if not match:
168+
continue
169+
secs = float(match.group("secs"))
170+
nodeid = match.group("nodeid")
171+
durations.append((secs, nodeid))
172+
return durations
173+
174+
175+
def print_aggregated_durations(
176+
durations: list[tuple[float, str]],
177+
list_limit: int,
178+
) -> None:
179+
"""Print durations aggregated by file and by test (collapsing parametrizations)."""
180+
181+
if not durations:
182+
print("\n[durations] no --durations lines found in pytest output.")
183+
return
184+
185+
by_file: dict[str, float] = defaultdict(float)
186+
by_test: dict[str, float] = defaultdict(float)
187+
188+
for secs, nodeid in durations:
189+
base = nodeid.split("[", 1)[0]
190+
file_name = base.split("::", 1)[0]
191+
by_file[file_name] += secs
192+
by_test[base] += secs
193+
194+
def _print_section(title: str, mapping: dict[str, float]) -> None:
195+
print(f"\nAggregated durations ({title}):")
196+
items = sorted(mapping.items(), key=lambda kv: kv[1], reverse=True)
197+
if list_limit > 0:
198+
items = items[:list_limit]
199+
for name, total in items:
200+
print(f"{total:8.02f}s {name}")
201+
202+
_print_section("by file", by_file)
203+
_print_section("by test (parametrizations combined)", by_test)
204+
205+
206+
def truncate_pytest_durations_output(output: str, limit: int) -> str:
207+
"""Keep pytest's duration section header, but show only the top `limit` lines."""
208+
lines = output.splitlines()
209+
out_lines = []
210+
in_durations_section = False
211+
kept = 0
212+
213+
for line in lines:
214+
if "slowest" in line and "durations" in line:
215+
in_durations_section = True
216+
kept = 0
217+
out_lines.append(line)
218+
continue
219+
220+
if in_durations_section:
221+
# Stop after we've shown N durations or reached next blank section
222+
if not line.strip():
223+
in_durations_section = False
224+
elif kept >= limit:
225+
continue
226+
else:
227+
kept += 1
228+
229+
out_lines.append(line)
230+
return "\n".join(out_lines)
231+
232+
233+
def export_to_file(result, args, filtered_stdout, durations):
234+
sys.stdout.write(filtered_stdout)
235+
sys.stderr.write(result.stderr)
236+
237+
# Write the filtered output to a file as well
238+
results_path = Path("pytest_profile_stats.txt")
239+
results_path.write_text(filtered_stdout)
240+
241+
if durations:
242+
print_aggregated_durations(durations, args.list_limit)
243+
244+
with results_path.open("a") as f:
245+
f.write("\n\n[Aggregated Durations]\n")
246+
for secs, nodeid in durations:
247+
f.write(f"{secs:.2f}s {nodeid}\n")
248+
249+
250+
def main() -> int:
251+
args = parse_args()
252+
ensure_poetry_available()
253+
254+
if args.debug and args.debug_limit <= 0:
255+
raise SystemExit("--debug-limit must be a positive integer.")
256+
257+
tests = args.tests or ["tests"]
258+
extra_args = shlex.split(args.pytest_args)
259+
260+
# Handle debug collection (collect-only)
261+
if args.debug:
262+
collected = collect_node_ids(extra_args, tests)
263+
pytest_targets = collected[: args.debug_limit]
264+
print(f"\nDebug mode: running the first {len(pytest_targets)} collected test(s).")
265+
else:
266+
pytest_targets = tests
267+
268+
# Build the full pytest command
269+
base_cmd = build_pytest_base(args.profile, Path(args.profile_output))
270+
pytest_cmd = base_cmd + extra_args
271+
if args.durations is not None:
272+
pytest_cmd.append(f"--durations={args.durations}")
273+
pytest_cmd.extend(pytest_targets)
274+
275+
print(f"\nExecuting: {' '.join(shlex.quote(part) for part in pytest_cmd)}\n")
276+
277+
# Run pytest
278+
result = subprocess.run(
279+
pytest_cmd,
280+
check=False,
281+
text=True,
282+
capture_output=True,
283+
)
284+
285+
# Extract and truncate outputs
286+
filtered_stdout = truncate_pytest_durations_output(result.stdout, args.list_limit)
287+
durations = extract_durations_from_output(result.stdout) if args.durations is not None else []
288+
289+
# Print once and export
290+
export_to_file(result, args, filtered_stdout, durations)
291+
292+
# Profile summary (if enabled)
293+
if args.profile and result.returncode == 0:
294+
summarize_profile(Path(args.profile_output), args.profile_sort, args.profile_top)
295+
296+
return result.returncode
297+
298+
299+
if __name__ == "__main__":
300+
raise SystemExit(main())

0 commit comments

Comments
 (0)