Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 70 additions & 14 deletions codex-rs/exec-server/src/posix/escalate_client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::io;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd as _;
Expand Down Expand Up @@ -34,21 +35,15 @@ pub(crate) async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32>
.send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
.await
.context("failed to send handshake datagram")?;
let env = std::env::vars()
.filter(|(k, _)| {
!matches!(
k.as_str(),
ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR
)
})
.collect();
let env = filter_env(std::env::vars());
let request = EscalateRequest {
file: file.clone().into(),
argv: argv.clone(),
workdir: std::env::current_dir()?,
env,
};
client
.send(EscalateRequest {
file: file.clone().into(),
argv: argv.clone(),
workdir: std::env::current_dir()?,
env,
})
.send(request)
.await
.context("failed to send EscalateRequest")?;
let message = client.receive::<EscalateResponse>().await?;
Expand Down Expand Up @@ -107,3 +102,64 @@ pub(crate) async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32>
}
}
}

fn filter_env<I>(env_iter: I) -> HashMap<String, String>
where
I: IntoIterator<Item = (String, String)>,
{
const MAX_ENV_ENTRY_LEN: i64 = 8_192;
let mut env = HashMap::new();
for (key, value) in env_iter {
if matches!(
key.as_str(),
ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR
) {
continue;
}
let entry_len = (key.len() + value.len()) as i64;
if entry_len > MAX_ENV_ENTRY_LEN {
tracing::debug!(key, entry_len, "skipping oversized environment variable");
continue;
}
env.insert(key, value);
}
env
}

#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;

#[test]
fn filter_env_drops_oversized_and_reserved_entries() {
let oversized_value = "A".repeat(8_193);
let env = vec![
("KEEP".to_string(), "ok".to_string()),
("DROP".to_string(), oversized_value),
(
ESCALATE_SOCKET_ENV_VAR.to_string(),
"should_skip".to_string(),
),
(
BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
"should_skip".to_string(),
),
];
let filtered = filter_env(env);
assert_eq!(Some(&"ok".to_string()), filtered.get("KEEP"));
assert!(!filtered.contains_key("DROP"));
assert!(!filtered.contains_key(ESCALATE_SOCKET_ENV_VAR));
assert!(!filtered.contains_key(BASH_EXEC_WRAPPER_ENV_VAR));
}

#[test]
fn filter_env_keeps_entries_at_limit() {
const KEY: &str = "KEEP";
let value_len = 8_192 - KEY.len();
let env = vec![(KEY.to_string(), "A".repeat(value_len))];
let filtered = filter_env(env);
assert_eq!(1, filtered.len());
assert_eq!(value_len, filtered[KEY].len());
}
}
59 changes: 42 additions & 17 deletions codex-rs/exec-server/src/posix/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,26 @@ fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io:
frame.extend_from_slice(&encode_length(data.len())?);
frame.extend_from_slice(data);

let mut control = vec![0u8; control_space_for_fds(fds.len())];
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len = libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
for (i, fd) in fds.iter().enumerate() {
data_ptr.add(i).write(fd.as_raw_fd());
}
}

let mut control;
let payload = [IoSlice::new(&frame)];
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
let mut sent = socket.sendmsg(&msg, 0)?;
let mut sent = if fds.is_empty() {
socket.send(&frame)?
} else {
control = vec![0u8; control_space_for_fds(fds.len())];
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len =
libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
for (i, fd) in fds.iter().enumerate() {
data_ptr.add(i).write(fd.as_raw_fd());
}
}
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
socket.sendmsg(&msg, 0)?
};
while sent < frame.len() {
let bytes = socket.send(&frame[sent..])?;
if bytes == 0 {
Expand Down Expand Up @@ -236,8 +241,9 @@ fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io
format!("too many fds: {}", fds.len()),
));
}
let mut control = vec![0u8; control_space_for_fds(fds.len())];
if !fds.is_empty() {

let control = if !fds.is_empty() {
let mut control = vec![0u8; control_space_for_fds(fds.len())];
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len =
Expand All @@ -249,7 +255,10 @@ fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io
data_ptr.add(i).write(fd.as_raw_fd());
}
}
}
control
} else {
vec![]
};
let payload = [IoSlice::new(data)];
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
let written = socket.sendmsg(&msg, 0)?;
Expand Down Expand Up @@ -433,6 +442,22 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn async_socket_round_trips_without_fds() -> std::io::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let payload = TestPayload {
id: 13,
label: "no-fds".to_string(),
};

let receive_task = tokio::spawn(async move { server.receive::<TestPayload>().await });
client.send(payload.clone()).await?;

let received_payload = receive_task.await.unwrap()?;
assert_eq!(payload, received_payload);
Ok(())
}

#[tokio::test]
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
let (server, client) = AsyncDatagramSocket::pair()?;
Expand Down
Loading