Skip to content

Commit 4db2fde

Browse files
ahoppenfabianfett
andauthored
Implement COPY … FROM STDIN queries (#566)
Co-authored-by: Fabian Fett <fabianfett@apple.com>
1 parent 78114d4 commit 4db2fde

File tree

10 files changed

+1343
-72
lines changed

10 files changed

+1343
-72
lines changed
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
2+
public struct PostgresCopyFromWriter: Sendable {
3+
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
4+
private let eventLoop: any EventLoop
5+
6+
init(handler: PostgresChannelHandler, eventLoop: any EventLoop) {
7+
self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
8+
self.eventLoop = eventLoop
9+
}
10+
11+
private func writeAssumingInEventLoop(_ byteBuffer: ByteBuffer, _ continuation: CheckedContinuation<Void, any Error>) {
12+
precondition(eventLoop.inEventLoop)
13+
let promise = eventLoop.makePromise(of: Void.self)
14+
self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise)
15+
promise.futureResult.flatMap {
16+
if self.eventLoop.inEventLoop {
17+
return self.eventLoop.makeCompletedFuture(withResultOf: {
18+
try self.channelHandler.value.sendCopyData(byteBuffer)
19+
})
20+
} else {
21+
let promise = self.eventLoop.makePromise(of: Void.self)
22+
self.eventLoop.execute {
23+
promise.completeWith(Result(catching: { try self.channelHandler.value.sendCopyData(byteBuffer) }))
24+
}
25+
return promise.futureResult
26+
}
27+
}.whenComplete { result in
28+
continuation.resume(with: result)
29+
}
30+
}
31+
32+
/// Send data for a `COPY ... FROM STDIN` operation to the backend.
33+
///
34+
/// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy
35+
/// operation, eg. to indicate that a **previous** `write` call had an invalid format.
36+
public func write(_ byteBuffer: ByteBuffer) async throws {
37+
// Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the
38+
// `writeData` closure. It is likely that the user would forget to do so.
39+
try Task.checkCancellation()
40+
41+
try await withTaskCancellationHandler {
42+
do {
43+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
44+
if self.eventLoop.inEventLoop {
45+
writeAssumingInEventLoop(byteBuffer, continuation)
46+
} else {
47+
self.eventLoop.execute {
48+
writeAssumingInEventLoop(byteBuffer, continuation)
49+
}
50+
}
51+
}
52+
} catch {
53+
if Task.isCancelled {
54+
// If the task was cancelled, we might receive a postgres error which is an artifact about how we
55+
// communicate the cancellation to the state machine. Throw a `CancellationError` to the user
56+
// instead, which looks more like native Swift Concurrency code.
57+
throw CancellationError()
58+
}
59+
throw error
60+
}
61+
} onCancel: {
62+
if self.eventLoop.inEventLoop {
63+
self.channelHandler.value.cancel()
64+
} else {
65+
self.eventLoop.execute {
66+
self.channelHandler.value.cancel()
67+
}
68+
}
69+
}
70+
}
71+
72+
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
73+
/// the backend.
74+
func done() async throws {
75+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
76+
if self.eventLoop.inEventLoop {
77+
self.channelHandler.value.sendCopyDone(continuation: continuation)
78+
} else {
79+
self.eventLoop.execute {
80+
self.channelHandler.value.sendCopyDone(continuation: continuation)
81+
}
82+
}
83+
}
84+
}
85+
86+
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
87+
/// the backend.
88+
func failed(error: any Error) async throws {
89+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
90+
if self.eventLoop.inEventLoop {
91+
self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation)
92+
} else {
93+
self.eventLoop.execute {
94+
self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation)
95+
}
96+
}
97+
}
98+
}
99+
}
100+
101+
/// Specifies the format in which data is transferred to the backend in a COPY operation.
102+
///
103+
/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings
104+
/// and their default values.
105+
public struct PostgresCopyFromFormat: Sendable {
106+
/// Options that can be used to modify the `text` format of a COPY operation.
107+
public struct TextOptions: Sendable {
108+
/// The delimiter that separates columns in the data.
109+
///
110+
/// See the `DELIMITER` option in Postgres's `COPY` command.
111+
public var delimiter: UnicodeScalar? = nil
112+
113+
public init() {}
114+
}
115+
116+
enum Format {
117+
case text(TextOptions)
118+
}
119+
120+
var format: Format
121+
122+
public static func text(_ options: TextOptions) -> PostgresCopyFromFormat {
123+
return PostgresCopyFromFormat(format: .text(options))
124+
}
125+
}
126+
127+
/// Create a `COPY ... FROM STDIN` query based on the given parameters.
128+
///
129+
/// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be
130+
/// copied by the caller.
131+
///
132+
/// - Warning: The table and column names are inserted into the `COPY FROM` query as passed and might thus be
133+
/// susceptible to SQL injection. Ensure no untrusted data is contained in these strings.
134+
private func buildCopyFromQuery(
135+
table: String,
136+
columns: [String] = [],
137+
format: PostgresCopyFromFormat
138+
) -> PostgresQuery {
139+
var query = """
140+
COPY "\(table)"
141+
"""
142+
if !columns.isEmpty {
143+
query += "("
144+
query += columns.map { #""\#($0)""# }.joined(separator: ",")
145+
query += ")"
146+
}
147+
query += " FROM STDIN"
148+
var queryOptions: [String] = []
149+
switch format.format {
150+
case .text(let options):
151+
queryOptions.append("FORMAT text")
152+
if let delimiter = options.delimiter {
153+
// Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection.
154+
queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'")
155+
}
156+
}
157+
precondition(!queryOptions.isEmpty)
158+
query += " WITH ("
159+
query += queryOptions.map { "\($0)" }.joined(separator: ",")
160+
query += ")"
161+
return "\(unescaped: query)"
162+
}
163+
164+
extension PostgresConnection {
165+
/// Copy data into a table using a `COPY <table name> FROM STDIN` query.
166+
///
167+
/// - Parameters:
168+
/// - table: The name of the table into which to copy the data.
169+
/// - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied.
170+
/// - format: Options that specify the format of the data that is produced by `writeData`.
171+
/// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the
172+
/// writer provided by the closure to send data to the backend and return from the closure once all data is sent.
173+
/// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown
174+
/// by the `copyFrom` function.
175+
///
176+
/// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be
177+
/// susceptible to SQL injection. Ensure no untrusted data is contained in these strings.
178+
public func copyFrom(
179+
table: String,
180+
columns: [String] = [],
181+
format: PostgresCopyFromFormat = .text(.init()),
182+
logger: Logger,
183+
file: String = #fileID,
184+
line: Int = #line,
185+
writeData: (PostgresCopyFromWriter) async throws -> Void
186+
) async throws {
187+
var logger = logger
188+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
189+
let writer: PostgresCopyFromWriter = try await withCheckedThrowingContinuation { continuation in
190+
let context = ExtendedQueryContext(
191+
copyFromQuery: buildCopyFromQuery(table: table, columns: columns, format: format),
192+
triggerCopy: continuation,
193+
logger: logger
194+
)
195+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
196+
}
197+
198+
do {
199+
try await writeData(writer)
200+
} catch {
201+
// We need to send a `CopyFail` to the backend to put it out of copy mode. This will most likely throw, most
202+
// notably for the following two reasons. In both of them, it's better to ignore the error thrown by
203+
// `writer.failed` and instead throw the error from `writeData`:
204+
// - We send `CopyFail` and the backend replies with an `ErrorResponse` that relays the `CopyFail` message.
205+
// This took the backend out of copy mode but it's more informative to the user to see the error they
206+
// threw instead of the one that got relayed back, so it's better to ignore the error here.
207+
// - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts
208+
// the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger
209+
// a `Sync` that takes the backend out of copy mode. If `writeData` threw the error from from the
210+
// `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it doesn't
211+
// matter that we ignore the error here. If the user threw some other error, it's better to honor the
212+
// user's error.
213+
try? await writer.failed(error: error)
214+
215+
throw error
216+
}
217+
218+
// `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during
219+
// the transfer of the last bit of data so that the user didn't call `PostgresCopyFromWriter.write` again, which
220+
// would have checked the error state. In either of these cases, calling `writer.done` puts the backend out of
221+
// copy mode, so we don't need to send another `CopyFail`. Thus, this must not be handled in the `do` block
222+
// above.
223+
try await writer.done()
224+
}
225+
}

0 commit comments

Comments
 (0)