|
| 1 | +/// Handle to send data for a `COPY ... FROM STDIN` query to the backend. |
| 2 | +public struct PostgresCopyFromWriter: Sendable { |
| 3 | + private let channelHandler: NIOLoopBound<PostgresChannelHandler> |
| 4 | + private let eventLoop: any EventLoop |
| 5 | + |
| 6 | + init(handler: PostgresChannelHandler, eventLoop: any EventLoop) { |
| 7 | + self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop) |
| 8 | + self.eventLoop = eventLoop |
| 9 | + } |
| 10 | + |
| 11 | + private func writeAssumingInEventLoop(_ byteBuffer: ByteBuffer, _ continuation: CheckedContinuation<Void, any Error>) { |
| 12 | + precondition(eventLoop.inEventLoop) |
| 13 | + let promise = eventLoop.makePromise(of: Void.self) |
| 14 | + self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise) |
| 15 | + promise.futureResult.flatMap { |
| 16 | + if self.eventLoop.inEventLoop { |
| 17 | + return self.eventLoop.makeCompletedFuture(withResultOf: { |
| 18 | + try self.channelHandler.value.sendCopyData(byteBuffer) |
| 19 | + }) |
| 20 | + } else { |
| 21 | + let promise = self.eventLoop.makePromise(of: Void.self) |
| 22 | + self.eventLoop.execute { |
| 23 | + promise.completeWith(Result(catching: { try self.channelHandler.value.sendCopyData(byteBuffer) })) |
| 24 | + } |
| 25 | + return promise.futureResult |
| 26 | + } |
| 27 | + }.whenComplete { result in |
| 28 | + continuation.resume(with: result) |
| 29 | + } |
| 30 | + } |
| 31 | + |
| 32 | + /// Send data for a `COPY ... FROM STDIN` operation to the backend. |
| 33 | + /// |
| 34 | + /// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy |
| 35 | + /// operation, eg. to indicate that a **previous** `write` call had an invalid format. |
| 36 | + public func write(_ byteBuffer: ByteBuffer) async throws { |
| 37 | + // Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the |
| 38 | + // `writeData` closure. It is likely that the user would forget to do so. |
| 39 | + try Task.checkCancellation() |
| 40 | + |
| 41 | + try await withTaskCancellationHandler { |
| 42 | + do { |
| 43 | + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in |
| 44 | + if self.eventLoop.inEventLoop { |
| 45 | + writeAssumingInEventLoop(byteBuffer, continuation) |
| 46 | + } else { |
| 47 | + self.eventLoop.execute { |
| 48 | + writeAssumingInEventLoop(byteBuffer, continuation) |
| 49 | + } |
| 50 | + } |
| 51 | + } |
| 52 | + } catch { |
| 53 | + if Task.isCancelled { |
| 54 | + // If the task was cancelled, we might receive a postgres error which is an artifact about how we |
| 55 | + // communicate the cancellation to the state machine. Throw a `CancellationError` to the user |
| 56 | + // instead, which looks more like native Swift Concurrency code. |
| 57 | + throw CancellationError() |
| 58 | + } |
| 59 | + throw error |
| 60 | + } |
| 61 | + } onCancel: { |
| 62 | + if self.eventLoop.inEventLoop { |
| 63 | + self.channelHandler.value.cancel() |
| 64 | + } else { |
| 65 | + self.eventLoop.execute { |
| 66 | + self.channelHandler.value.cancel() |
| 67 | + } |
| 68 | + } |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to |
| 73 | + /// the backend. |
| 74 | + func done() async throws { |
| 75 | + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in |
| 76 | + if self.eventLoop.inEventLoop { |
| 77 | + self.channelHandler.value.sendCopyDone(continuation: continuation) |
| 78 | + } else { |
| 79 | + self.eventLoop.execute { |
| 80 | + self.channelHandler.value.sendCopyDone(continuation: continuation) |
| 81 | + } |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to |
| 87 | + /// the backend. |
| 88 | + func failed(error: any Error) async throws { |
| 89 | + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in |
| 90 | + if self.eventLoop.inEventLoop { |
| 91 | + self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation) |
| 92 | + } else { |
| 93 | + self.eventLoop.execute { |
| 94 | + self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation) |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +/// Specifies the format in which data is transferred to the backend in a COPY operation. |
| 102 | +/// |
| 103 | +/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings |
| 104 | +/// and their default values. |
| 105 | +public struct PostgresCopyFromFormat: Sendable { |
| 106 | + /// Options that can be used to modify the `text` format of a COPY operation. |
| 107 | + public struct TextOptions: Sendable { |
| 108 | + /// The delimiter that separates columns in the data. |
| 109 | + /// |
| 110 | + /// See the `DELIMITER` option in Postgres's `COPY` command. |
| 111 | + public var delimiter: UnicodeScalar? = nil |
| 112 | + |
| 113 | + public init() {} |
| 114 | + } |
| 115 | + |
| 116 | + enum Format { |
| 117 | + case text(TextOptions) |
| 118 | + } |
| 119 | + |
| 120 | + var format: Format |
| 121 | + |
| 122 | + public static func text(_ options: TextOptions) -> PostgresCopyFromFormat { |
| 123 | + return PostgresCopyFromFormat(format: .text(options)) |
| 124 | + } |
| 125 | +} |
| 126 | + |
| 127 | +/// Create a `COPY ... FROM STDIN` query based on the given parameters. |
| 128 | +/// |
| 129 | +/// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be |
| 130 | +/// copied by the caller. |
| 131 | +/// |
| 132 | +/// - Warning: The table and column names are inserted into the `COPY FROM` query as passed and might thus be |
| 133 | +/// susceptible to SQL injection. Ensure no untrusted data is contained in these strings. |
| 134 | +private func buildCopyFromQuery( |
| 135 | + table: String, |
| 136 | + columns: [String] = [], |
| 137 | + format: PostgresCopyFromFormat |
| 138 | +) -> PostgresQuery { |
| 139 | + var query = """ |
| 140 | + COPY "\(table)" |
| 141 | + """ |
| 142 | + if !columns.isEmpty { |
| 143 | + query += "(" |
| 144 | + query += columns.map { #""\#($0)""# }.joined(separator: ",") |
| 145 | + query += ")" |
| 146 | + } |
| 147 | + query += " FROM STDIN" |
| 148 | + var queryOptions: [String] = [] |
| 149 | + switch format.format { |
| 150 | + case .text(let options): |
| 151 | + queryOptions.append("FORMAT text") |
| 152 | + if let delimiter = options.delimiter { |
| 153 | + // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection. |
| 154 | + queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'") |
| 155 | + } |
| 156 | + } |
| 157 | + precondition(!queryOptions.isEmpty) |
| 158 | + query += " WITH (" |
| 159 | + query += queryOptions.map { "\($0)" }.joined(separator: ",") |
| 160 | + query += ")" |
| 161 | + return "\(unescaped: query)" |
| 162 | +} |
| 163 | + |
| 164 | +extension PostgresConnection { |
| 165 | + /// Copy data into a table using a `COPY <table name> FROM STDIN` query. |
| 166 | + /// |
| 167 | + /// - Parameters: |
| 168 | + /// - table: The name of the table into which to copy the data. |
| 169 | + /// - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied. |
| 170 | + /// - format: Options that specify the format of the data that is produced by `writeData`. |
| 171 | + /// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the |
| 172 | + /// writer provided by the closure to send data to the backend and return from the closure once all data is sent. |
| 173 | + /// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown |
| 174 | + /// by the `copyFrom` function. |
| 175 | + /// |
| 176 | + /// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be |
| 177 | + /// susceptible to SQL injection. Ensure no untrusted data is contained in these strings. |
| 178 | + public func copyFrom( |
| 179 | + table: String, |
| 180 | + columns: [String] = [], |
| 181 | + format: PostgresCopyFromFormat = .text(.init()), |
| 182 | + logger: Logger, |
| 183 | + file: String = #fileID, |
| 184 | + line: Int = #line, |
| 185 | + writeData: (PostgresCopyFromWriter) async throws -> Void |
| 186 | + ) async throws { |
| 187 | + var logger = logger |
| 188 | + logger[postgresMetadataKey: .connectionID] = "\(self.id)" |
| 189 | + let writer: PostgresCopyFromWriter = try await withCheckedThrowingContinuation { continuation in |
| 190 | + let context = ExtendedQueryContext( |
| 191 | + copyFromQuery: buildCopyFromQuery(table: table, columns: columns, format: format), |
| 192 | + triggerCopy: continuation, |
| 193 | + logger: logger |
| 194 | + ) |
| 195 | + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) |
| 196 | + } |
| 197 | + |
| 198 | + do { |
| 199 | + try await writeData(writer) |
| 200 | + } catch { |
| 201 | + // We need to send a `CopyFail` to the backend to put it out of copy mode. This will most likely throw, most |
| 202 | + // notably for the following two reasons. In both of them, it's better to ignore the error thrown by |
| 203 | + // `writer.failed` and instead throw the error from `writeData`: |
| 204 | + // - We send `CopyFail` and the backend replies with an `ErrorResponse` that relays the `CopyFail` message. |
| 205 | + // This took the backend out of copy mode but it's more informative to the user to see the error they |
| 206 | + // threw instead of the one that got relayed back, so it's better to ignore the error here. |
| 207 | + // - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts |
| 208 | + // the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger |
| 209 | + // a `Sync` that takes the backend out of copy mode. If `writeData` threw the error from from the |
| 210 | + // `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it doesn't |
| 211 | + // matter that we ignore the error here. If the user threw some other error, it's better to honor the |
| 212 | + // user's error. |
| 213 | + try? await writer.failed(error: error) |
| 214 | + |
| 215 | + throw error |
| 216 | + } |
| 217 | + |
| 218 | + // `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during |
| 219 | + // the transfer of the last bit of data so that the user didn't call `PostgresCopyFromWriter.write` again, which |
| 220 | + // would have checked the error state. In either of these cases, calling `writer.done` puts the backend out of |
| 221 | + // copy mode, so we don't need to send another `CopyFail`. Thus, this must not be handled in the `do` block |
| 222 | + // above. |
| 223 | + try await writer.done() |
| 224 | + } |
| 225 | +} |
0 commit comments