Skip to content

Commit b3a585c

Browse files
committed
feat: add handleUpgradeRequest route option
1 parent b6d982a commit b3a585c

File tree

5 files changed

+245
-4
lines changed

5 files changed

+245
-4
lines changed

README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,47 @@ fastify.register(require('@fastify/websocket'), {
263263
})
264264
```
265265

266+
### Custom upgrade request handling
267+
268+
By default, `@fastify/websocket` handles upgrading incoming connections to the websocket protocol before handing off the handler you have defined.
269+
If you wish to handle upgrade events yourself you can pass your own `handleUpgradeRequest` function:
270+
271+
```js
272+
const fastify = require('fastify')()
273+
274+
fastify.register(require('@fastify/websocket'))
275+
276+
fastify.register(async function () {
277+
fastify.route({
278+
method: 'GET',
279+
url: '/hello',
280+
handleUpgradeRequest: (request, source, head) => {
281+
// handle the FastifyRequest which has triggered an upgrade event
282+
// throwing an error will abort the upgrade
283+
// return a Promise for a Websocket to proceed
284+
if (request.params.allow === "false") {
285+
const error = new Error("Upgrade not allow")
286+
error.statusCode = 403
287+
throw error
288+
} else {
289+
return new Promise((resolve) => {
290+
fastify.websocketServer.handleUpgrade(request.raw, socket, head, (ws) => {
291+
resolve(ws)
292+
})
293+
})
294+
}
295+
}
296+
wsHandler: (socket, req) => {
297+
socket.send('hello client')
298+
299+
socket.once('message', chunk => {
300+
socket.close()
301+
})
302+
}
303+
})
304+
})
305+
```
306+
266307
### Creating a stream from the WebSocket
267308

268309
```js

index.js

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ function fastifyWebsocket (fastify, opts, next) {
123123
}
124124
websocketListenServer.on('upgrade', onUpgrade)
125125

126-
const handleUpgrade = (rawRequest, callback) => {
126+
const defaultHandleUpgrade = (request, _reply, callback) => {
127+
const rawRequest = request.raw
127128
wss.handleUpgrade(rawRequest, rawRequest[kWs], rawRequest[kWsHead], (socket) => {
128129
wss.emit('connection', socket, rawRequest)
129130

@@ -155,6 +156,22 @@ function fastifyWebsocket (fastify, opts, next) {
155156
let isWebsocketRoute = false
156157
let wsHandler = routeOptions.wsHandler
157158
let handler = routeOptions.handler
159+
const handleUpgrade = routeOptions.handleUpgradeRequest
160+
? (request, reply, callback) => {
161+
const rawRequest = request.raw
162+
routeOptions.handleUpgradeRequest(request, rawRequest[kWs], rawRequest[kWsHead])
163+
.then(socket => {
164+
callback(socket)
165+
})
166+
.catch(error => {
167+
const ended = reply.raw.writableEnded || reply.raw.socket.writableEnded
168+
if (!ended) {
169+
reply.raw.statusCode = error.statusCode || 500
170+
reply.raw.end(error.message)
171+
}
172+
})
173+
}
174+
: defaultHandleUpgrade
158175

159176
if (routeOptions.websocket || routeOptions.wsHandler) {
160177
if (routeOptions.method === 'HEAD') {
@@ -188,7 +205,7 @@ function fastifyWebsocket (fastify, opts, next) {
188205
// within the route handler, we check if there has been a connection upgrade by looking at request.raw[kWs]. we need to dispatch the normal HTTP handler if not, and hijack to dispatch the websocket handler if so
189206
if (request.raw[kWs]) {
190207
reply.hijack()
191-
handleUpgrade(request.raw, socket => {
208+
const onUpgrade = (socket) => {
192209
let result
193210
try {
194211
if (isWebsocketRoute) {
@@ -203,7 +220,9 @@ function fastifyWebsocket (fastify, opts, next) {
203220
if (result && typeof result.catch === 'function') {
204221
result.catch(err => errorHandler.call(this, err, socket, request, reply))
205222
}
206-
})
223+
}
224+
225+
handleUpgrade(request, reply, onUpgrade)
207226
} else {
208227
return handler.call(this, request, reply)
209228
}

test/base.test.js

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,160 @@ test('clashing upgrade handler', async (t) => {
661661
const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
662662
await once(ws, 'error')
663663
})
664+
665+
test('Should handleUpgradeRequest successfully', async (t) => {
666+
t.plan(4)
667+
668+
const fastify = Fastify()
669+
t.after(() => fastify.close())
670+
671+
await fastify.register(fastifyWebsocket)
672+
673+
let customUpgradeCalled = false
674+
675+
fastify.get('/', {
676+
websocket: true,
677+
handleUpgradeRequest: async (request, socket, head) => {
678+
customUpgradeCalled = true
679+
t.assert.equal(typeof socket, 'object', 'socket parameter is provided')
680+
t.assert.equal(Buffer.isBuffer(head), true, 'head parameter is a buffer')
681+
682+
return new Promise((resolve) => {
683+
fastify.websocketServer.handleUpgrade(request.raw, socket, head, (ws) => {
684+
resolve(ws)
685+
})
686+
})
687+
}
688+
}, (socket) => {
689+
socket.on('message', (data) => {
690+
socket.send(`echo: ${data}`)
691+
})
692+
t.after(() => socket.terminate())
693+
})
694+
695+
await fastify.listen({ port: 0 })
696+
697+
const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
698+
t.after(() => ws.close())
699+
700+
await once(ws, 'open')
701+
ws.send('hello')
702+
703+
const [message] = await once(ws, 'message')
704+
t.assert.equal(message.toString(), 'echo: hello')
705+
706+
t.assert.ok(customUpgradeCalled, 'handleUpgradeRequest was called')
707+
})
708+
709+
test.only('Should handle errors thrown in handleUpgradeRequest', async (t) => {
710+
t.plan(1)
711+
712+
const fastify = Fastify()
713+
t.after(() => fastify.close())
714+
715+
await fastify.register(fastifyWebsocket)
716+
717+
fastify.get('/', {
718+
websocket: true,
719+
handleUpgradeRequest: async () => {
720+
throw new Error('Custom upgrade error')
721+
}
722+
}, () => {
723+
t.fail('websocket handler should not be called when upgrade fails')
724+
})
725+
726+
await fastify.listen({ port: 0 })
727+
728+
const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
729+
730+
let wsErrorResolved
731+
const wsErrorPromise = new Promise((resolve) => {
732+
wsErrorResolved = resolve
733+
})
734+
735+
ws.on('error', (error) => {
736+
wsErrorResolved(error)
737+
})
738+
739+
const wsError = await wsErrorPromise
740+
741+
t.assert.equal(wsError.message, 'Unexpected server response: 500')
742+
})
743+
744+
test('Should allow for handleUpgradeRequest to send a response to the client before throwing an error', async (t) => {
745+
t.plan(1)
746+
747+
const fastify = Fastify()
748+
t.after(() => fastify.close())
749+
750+
await fastify.register(fastifyWebsocket)
751+
752+
fastify.get('/', {
753+
websocket: true,
754+
handleUpgradeRequest: async () => {
755+
const error = new Error('Forbidden')
756+
error.statusCode = 403
757+
throw error
758+
}
759+
}, () => {
760+
t.fail('websocket handler should not be called when upgrade fails')
761+
})
762+
763+
await fastify.listen({ port: 0 })
764+
765+
const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
766+
767+
let wsErrorResolved
768+
const wsErrorPromise = new Promise((resolve) => {
769+
wsErrorResolved = resolve
770+
})
771+
772+
ws.on('error', (error) => {
773+
wsErrorResolved(error)
774+
})
775+
776+
const wsError = await wsErrorPromise
777+
778+
t.assert.equal(wsError.message, 'Unexpected server response: 403')
779+
})
780+
781+
test('Should not send a response if handleUpgradeRequest has already ended the underlying socket and thrown an error', async (t) => {
782+
t.plan(1)
783+
784+
const fastify = Fastify()
785+
t.after(() => fastify.close())
786+
787+
await fastify.register(fastifyWebsocket)
788+
789+
fastify.get('/', {
790+
websocket: true,
791+
handleUpgradeRequest: async (request, socket, head) => {
792+
socket.write('HTTP/1.1 400 Bad Request\r\n')
793+
socket.write('Connection: closed\r\n')
794+
socket.write('\r\n')
795+
socket.end()
796+
socket.destroy()
797+
798+
throw new Error('thrown after response has ended')
799+
}
800+
}, () => {
801+
t.fail('websocket handler should not be called when upgrade fails')
802+
})
803+
804+
await fastify.listen({ port: 0 })
805+
806+
const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
807+
808+
let wsErrorResolved
809+
const wsErrorPromise = new Promise((resolve) => {
810+
wsErrorResolved = resolve
811+
})
812+
813+
ws.on('error', (error) => {
814+
wsErrorResolved(error)
815+
})
816+
817+
const wsError = await wsErrorPromise
818+
819+
t.assert.equal(wsError.message, 'Unexpected server response: 400')
820+
})

types/index.d.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { preCloseAsyncHookHandler, preCloseHookHandler } from 'fastify/types/hoo
55
import { FastifyReply } from 'fastify/types/reply'
66
import { RouteGenericInterface } from 'fastify/types/route'
77
import { IncomingMessage, Server, ServerResponse } from 'node:http'
8+
import { Duplex } from 'node:stream'
89
import * as WebSocket from 'ws'
910

1011
interface WebsocketRouteOptions<
@@ -17,6 +18,7 @@ interface WebsocketRouteOptions<
1718
Logger extends FastifyBaseLogger = FastifyBaseLogger
1819
> {
1920
wsHandler?: fastifyWebsocket.WebsocketHandler<RawServer, RawRequest, RequestGeneric, ContextConfig, SchemaCompiler, TypeProvider, Logger>;
21+
handleUpgradeRequest?: (request: FastifyRequest<RequestGeneric, RawServer, RawRequest, SchemaCompiler, TypeProvider, ContextConfig, Logger>, rawSocket: Duplex, socketHead: Buffer) => Promise<WebSocket.WebSocket>;
2022
}
2123

2224
declare module 'fastify' {

types/index.test-d.ts

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ import fastify, { FastifyBaseLogger, FastifyInstance, FastifyReply, FastifyReque
44
import { RouteGenericInterface } from 'fastify/types/route'
55
import type { IncomingMessage } from 'node:http'
66
import { expectType } from 'tsd'
7-
import { Server } from 'ws'
7+
import { Server, WebSocket as BaseWebSocket } from 'ws'
88
// eslint-disable-next-line import-x/no-named-default -- Test default export
99
import fastifyWebsocket, { default as defaultFastifyWebsocket, fastifyWebsocket as namedFastifyWebsocket, WebSocket, WebsocketHandler } from '..'
10+
import { Duplex } from 'node:stream'
1011

1112
const app: FastifyInstance = fastify()
1213
app.register(fastifyWebsocket)
@@ -82,6 +83,27 @@ const augmentedRouteOptions: RouteOptions = {
8283
}
8384
app.route(augmentedRouteOptions)
8485

86+
const handleUpgradeRequestOptions: RouteOptions = {
87+
method: 'GET',
88+
url: '/route-with-handle-upgrade-request',
89+
handler: (request, reply) => {
90+
expectType<FastifyRequest>(request)
91+
expectType<FastifyReply>(reply)
92+
},
93+
handleUpgradeRequest: (request, socket, head) => {
94+
expectType<FastifyRequest>(request)
95+
expectType<Duplex>(socket)
96+
expectType<Buffer>(head)
97+
98+
return Promise.resolve(new BaseWebSocket('ws://localhost:8080'))
99+
},
100+
wsHandler: (socket, request) => {
101+
expectType<WebSocket>(socket)
102+
expectType<FastifyRequest<RouteGenericInterface>>(request)
103+
},
104+
}
105+
app.route(handleUpgradeRequestOptions)
106+
85107
app.get<{ Params: { foo: string }, Body: { bar: string }, Querystring: { search: string }, Headers: { auth: string } }>('/shorthand-explicit-types', {
86108
websocket: true
87109
}, async (socket, request) => {

0 commit comments

Comments
 (0)