Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 37 additions & 21 deletions .github/workflows/code-quality.yml
Original file line number Diff line number Diff line change
@@ -1,49 +1,65 @@
name: Code quality

on:
pull_request:
types: [opened, synchronize, reopened]

permissions:
contents: read

jobs:
lockfile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- run: uv lock --locked
lint:
runs-on: ubuntu-latest
needs: [lockfile]
needs: lockfile
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- run: uv run ruff check
format:
runs-on: ubuntu-latest
needs: [lockfile]
needs: lockfile
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- run: uv sync --locked --all-extras --dev
- run: uv run ruff format --check
typecheck:
runs-on: ubuntu-latest
needs: [lockfile]
needs: lockfile
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- run: uv sync --extra webapp
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- run: uv sync --locked --all-extras --dev
- run: uv run mypy
tests:
test:
runs-on: ubuntu-latest
needs: [lockfile]
needs: lockfile
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- run: uv run pytest -v --durations=0
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- run: uv run scripts/test
build:
runs-on: [ubuntu-latest]
needs: [lint, format, typecheck, tests]
runs-on: ubuntu-latest
needs: [lint, format, typecheck, test]
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- run: uv build
26 changes: 17 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
requires = ["uv_build"]
build-backend = "uv_build"


[project]
Expand All @@ -16,7 +16,7 @@ authors = [
{ name = "Soroush Hoseini", email = "soroushhoseini0@gmail.com" },
]
dependencies = [
"numpy<2.0",
"numpy>=2.3.5",
"pyfirmata2>=2.5.0",
"pyzmq>=26.2.0",
"thorlabs-apt-device>=0.3.8",
Expand All @@ -34,11 +34,11 @@ webapp = [
]

[dependency-groups]
dev = ["mypy>=1.13.0", "pytest-cov>=6.0.0", "ruff>=0.8.0"]
dev = ["hypothesis", "mypy", "coverage", "pytest-randomly", "ruff"]


[tool.mypy]
files = ["src"]
files = ["src/**/*.py"]
strict = true
pretty = true

Expand Down Expand Up @@ -76,8 +76,6 @@ extend-ignore = [
"ISC001",
"ISC002",
"E501",
"S101", # Assert
"S311", # We know its not secure
]

[tool.ruff.lint.extend-per-file-ignores]
Expand All @@ -95,5 +93,15 @@ force-single-line = true
[tool.ruff.lint.pydocstyle]
convention = "numpy" # One of: "google" | "numpy" | "pep257"

[tool.uv.sources]
TimeTagger = { path = "/usr/lib/python3/dist-packages" }

[tool.coverage.run]
# source = ["app"]
dynamic_context = "test_function"

[tool.coverage.report]
show_missing = true
skip_empty = true
sort = "-Cover"

[tool.coverage.html]
show_contexts = true
5 changes: 5 additions & 0 deletions scripts/format
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env bash
set -eux

ruff format src
ruff check src --fix
6 changes: 6 additions & 0 deletions scripts/lint
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash
set -eux

mypy
ruff format src --check
ruff check src
6 changes: 6 additions & 0 deletions scripts/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash
set -eux

coverage run -m pytest
coverage report
coverage html --title "${@-coverage}"
15 changes: 13 additions & 2 deletions src/pqnstack/app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,21 @@
import httpx
from fastapi import Depends

from pqnstack.app.core.config import settings
from pqnstack.network.client import Client


async def get_http_client() -> AsyncGenerator[httpx.AsyncClient, None]:
async with httpx.AsyncClient(timeout=600_000) as client:
async with httpx.AsyncClient(timeout=60) as client:
yield client


type ClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)]


async def get_instrument_client() -> AsyncGenerator[Client, None]:
async with Client(host=settings.router_address, port=settings.router_port, timeout=60) as client:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a way of specifying the timeout for for different instrument clients. Specially if we are using this to get the timetagger since that timeout is longer than other instruments.

Copy link
Contributor Author

@Benjamin-Nussbaum Benjamin-Nussbaum Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make timeout a parameter of get_device. We could also pre-bake dependencies for particular Instruments if we always want the same timeout for the same Instrument, but keeping it as an argument of get_device is more flexible.

In that case, do we still need to be able to set the timeout for the Client?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

timeout is still in ms, if we want to switch to seconds, we need to make that conversion in the client file itself before setting the default

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can leave it in ms for now. We can think about if we want to switch later.

yield client


ClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)]
type InstrumentClientDep = Annotated[httpx.AsyncClient, Depends(get_instrument_client)]
13 changes: 9 additions & 4 deletions src/pqnstack/app/api/routes/chsh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import TYPE_CHECKING
from typing import cast

from fastapi import APIRouter
Expand All @@ -11,6 +12,9 @@
from pqnstack.app.core.models import calculate_chsh_expectation_error
from pqnstack.network.client import Client

if TYPE_CHECKING:
from pqnstack.base.instrument import RotatorInstrument

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/chsh", tags=["chsh"])
Expand All @@ -28,7 +32,7 @@ async def _chsh( # Complexity is high due to the nature of the CHSH experiment.
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)

# TODO: Check if settings.chsh_settings.hwp is set before even trying to get the device.
hwp = client.get_device(settings.chsh_settings.hwp[0], settings.chsh_settings.hwp[1])
hwp = cast("RotatorInstrument", client.get_device(settings.chsh_settings.hwp[0], settings.chsh_settings.hwp[1]))
if hwp is None:
logger.error("Could not find half waveplate device")
raise HTTPException(
Expand All @@ -44,7 +48,6 @@ async def _chsh( # Complexity is high due to the nature of the CHSH experiment.
for i in range(2): # Going through follower basis angles
counts = []
for a in [angle, (angle + 90)]:
assert hasattr(hwp, "move_to")
hwp.move_to(a / 2)
for perp in [False, True]:
r = await http_client.post(
Expand Down Expand Up @@ -128,15 +131,17 @@ async def chsh(
@router.post("/request-angle-by-basis")
async def request_angle_by_basis(index: int, *, perp: bool = False) -> bool:
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
hwp = client.get_device(settings.chsh_settings.request_hwp[0], settings.chsh_settings.request_hwp[1])
hwp = cast(
"RotatorInstrument",
client.get_device(settings.chsh_settings.request_hwp[0], settings.chsh_settings.request_hwp[1]),
)
if hwp is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Could not find half waveplate device",
)

angle = state.chsh_request_basis[index] + 90 * perp
assert hasattr(hwp, "move_to")
hwp.move_to(angle / 2)
logger.info("moving waveplate", extra={"angle": angle})
return True
24 changes: 14 additions & 10 deletions src/pqnstack/app/api/routes/qkd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import random
import secrets
from typing import TYPE_CHECKING
from typing import cast

from fastapi import APIRouter
Expand All @@ -13,6 +14,9 @@
from pqnstack.constants import QKDEncodingBasis
from pqnstack.network.client import Client

if TYPE_CHECKING:
from pqnstack.base.instrument import RotatorInstrument

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/qkd", tags=["qkd"])
Expand All @@ -25,7 +29,7 @@ async def _qkd(
) -> list[int]:
logger.debug("Starting QKD")
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
hwp = client.get_device(settings.qkd_settings.hwp[0], settings.qkd_settings.hwp[1])
hwp = cast("RotatorInstrument", client.get_device(settings.qkd_settings.hwp[0], settings.qkd_settings.hwp[1]))

if hwp is None:
logger.error("Could not find half waveplate device")
Expand All @@ -46,10 +50,9 @@ async def _qkd(
)
logger.debug("Handshake with follower successful")

int_choice = random.randint(0, 1) # FIXME: Make this real quantum random.
int_choice = secrets.randbits(1) # FIXME: Make this real quantum random.
logger.debug("Chosen integer choice: %s", int_choice)
state.qkd_bit_list.append(int_choice)
assert hasattr(hwp, "move_to")
hwp.move_to(basis.angles[int_choice].value)
logger.debug("Moving half waveplate to angle: %s", basis.angles[int_choice].value)

Expand Down Expand Up @@ -122,7 +125,10 @@ async def qkd(
@router.post("/single_bit")
async def request_qkd_single_pass() -> bool:
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
hwp = client.get_device(settings.qkd_settings.request_hwp[0], settings.qkd_settings.request_hwp[1])
hwp = cast(
"RotatorInstrument",
client.get_device(settings.qkd_settings.request_hwp[0], settings.qkd_settings.request_hwp[1]),
)

if hwp is None:
logger.error("Could not find half waveplate device")
Expand All @@ -132,12 +138,10 @@ async def request_qkd_single_pass() -> bool:
)

logger.debug("Halfwaveplate device found: %s", hwp)
assert hasattr(hwp, "move_to")

basis_choice = random.choices([QKDEncodingBasis.HV, QKDEncodingBasis.DA])[
0
] # FIXME: Make this real quantum random.
int_choice = random.randint(0, 1) # FIXME: Make this real quantum random.
_bases = (QKDEncodingBasis.HV, QKDEncodingBasis.DA)
basis_choice = _bases[secrets.randbits(1)] # FIXME: Make this real quantum random.
int_choice = secrets.randbits(1) # FIXME: Make this real quantum random.

state.qkd_request_basis_list.append(basis_choice)
state.qkd_request_bit_list.append(int_choice)
Expand Down
16 changes: 11 additions & 5 deletions src/pqnstack/app/api/routes/timetagger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from typing import TYPE_CHECKING
from typing import Annotated
from typing import cast

from fastapi import APIRouter
from fastapi import HTTPException
Expand All @@ -10,6 +12,9 @@
from pqnstack.network.client import Client
from pqnstack.pqn.protocols.measurement import MeasurementConfig

if TYPE_CHECKING:
from pqnstack.base.instrument import TimeTaggerInstrument

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/timetagger", tags=["timetagger"])
Expand All @@ -30,10 +35,13 @@ async def measure_correlation(
)

mconf = MeasurementConfig(
integration_time_s=integration_time_s, binwidth_ps=coincidence_window_ps, channel1=channel1, channel2=channel2
integration_time_s=integration_time_s,
binwidth_ps=coincidence_window_ps,
channel1=channel1,
channel2=channel2,
)
client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
tagger = client.get_device(settings.timetagger[0], settings.timetagger[1])
tagger = cast("TimeTaggerInstrument", client.get_device(settings.timetagger[0], settings.timetagger[1]))
if tagger is None:
logger.error("Could not find time tagger device")
raise HTTPException(
Expand All @@ -42,7 +50,6 @@ async def measure_correlation(
)

logger.debug("Time tagger device found: %s", tagger)
assert hasattr(tagger, "measure_correlation")
count = tagger.measure_correlation(
mconf.channel1,
mconf.channel2,
Expand All @@ -67,7 +74,7 @@ async def count_singles(
)

client = Client(host=settings.router_address, port=settings.router_port, timeout=600_000)
tagger = client.get_device(settings.timetagger[0], settings.timetagger[1])
tagger = cast("TimeTaggerInstrument", client.get_device(settings.timetagger[0], settings.timetagger[1]))
if tagger is None:
logger.error("Could not find time tagger device")
raise HTTPException(
Expand All @@ -76,7 +83,6 @@ async def count_singles(
)

logger.debug("Time tagger device found: %s", tagger)
assert hasattr(tagger, "count_singles")
counts = tagger.count_singles(channels, integration_time_s=integration_time_s)

logger.info("Measured singles counts: %s", counts)
Expand Down
Loading