mirror of
https://github.com/mastodon/mastodon.git
synced 2025-01-02 21:23:18 +00:00
Streaming: Rework websocket server initialisation & authentication code (#28631)
This commit is contained in:
parent
e72676e83a
commit
58830be943
|
@ -182,14 +182,74 @@ const CHANNEL_NAMES = [
|
|||
];
|
||||
|
||||
const startServer = async () => {
|
||||
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
|
||||
const server = http.createServer();
|
||||
const wss = new WebSocket.Server({ noServer: true });
|
||||
|
||||
// Set the X-Request-Id header on WebSockets:
|
||||
wss.on("headers", function onHeaders(headers, req) {
|
||||
headers.push(`X-Request-Id: ${req.id}`);
|
||||
});
|
||||
|
||||
const app = express();
|
||||
|
||||
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
|
||||
|
||||
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
|
||||
const server = http.createServer(app);
|
||||
app.use(cors());
|
||||
|
||||
// Handle eventsource & other http requests:
|
||||
server.on('request', app);
|
||||
|
||||
// Handle upgrade requests:
|
||||
server.on('upgrade', async function handleUpgrade(request, socket, head) {
|
||||
/** @param {Error} err */
|
||||
const onSocketError = (err) => {
|
||||
log.error(`Error with websocket upgrade: ${err}`);
|
||||
};
|
||||
|
||||
socket.on('error', onSocketError);
|
||||
|
||||
// Authenticate:
|
||||
try {
|
||||
await accountFromRequest(request);
|
||||
} catch (err) {
|
||||
log.error(`Error authenticating request: ${err}`);
|
||||
|
||||
// Unfortunately for using the on('upgrade') setup, we need to manually
|
||||
// write a HTTP Response to the Socket to close the connection upgrade
|
||||
// attempt, so the following code is to handle all of that.
|
||||
const statusCode = err.status ?? 401;
|
||||
|
||||
/** @type {Record<string, string | number>} */
|
||||
const headers = {
|
||||
'Connection': 'close',
|
||||
'Content-Type': 'text/plain',
|
||||
'Content-Length': 0,
|
||||
'X-Request-Id': request.id,
|
||||
// TODO: Send the error message via header so it can be debugged in
|
||||
// developer tools
|
||||
};
|
||||
|
||||
// Ensure the socket is closed once we've finished writing to it:
|
||||
socket.once('finish', () => {
|
||||
socket.destroy();
|
||||
});
|
||||
|
||||
// Write the HTTP response manually:
|
||||
socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
wss.handleUpgrade(request, socket, head, function done(ws) {
|
||||
// Remove the error handler:
|
||||
socket.removeListener('error', onSocketError);
|
||||
|
||||
// Start the connection:
|
||||
wss.emit('connection', ws, request);
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @type {Object.<string, Array.<function(Object<string, any>): void>>}
|
||||
*/
|
||||
|
@ -360,10 +420,19 @@ const startServer = async () => {
|
|||
const isInScope = (req, necessaryScopes) =>
|
||||
req.scopes.some(scope => necessaryScopes.includes(scope));
|
||||
|
||||
/**
|
||||
* @typedef ResolvedAccount
|
||||
* @property {string} accessTokenId
|
||||
* @property {string[]} scopes
|
||||
* @property {string} accountId
|
||||
* @property {string[]} chosenLanguages
|
||||
* @property {string} deviceId
|
||||
*/
|
||||
|
||||
/**
|
||||
* @param {string} token
|
||||
* @param {any} req
|
||||
* @returns {Promise.<void>}
|
||||
* @returns {Promise<ResolvedAccount>}
|
||||
*/
|
||||
const accountFromToken = (token, req) => new Promise((resolve, reject) => {
|
||||
pgPool.connect((err, client, done) => {
|
||||
|
@ -394,14 +463,20 @@ const startServer = async () => {
|
|||
req.chosenLanguages = result.rows[0].chosen_languages;
|
||||
req.deviceId = result.rows[0].device_id;
|
||||
|
||||
resolve();
|
||||
resolve({
|
||||
accessTokenId: result.rows[0].id,
|
||||
scopes: result.rows[0].scopes.split(' '),
|
||||
accountId: result.rows[0].account_id,
|
||||
chosenLanguages: result.rows[0].chosen_languages,
|
||||
deviceId: result.rows[0].device_id
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @param {any} req
|
||||
* @returns {Promise.<void>}
|
||||
* @returns {Promise<ResolvedAccount>}
|
||||
*/
|
||||
const accountFromRequest = (req) => new Promise((resolve, reject) => {
|
||||
const authorization = req.headers.authorization;
|
||||
|
@ -494,25 +569,6 @@ const startServer = async () => {
|
|||
reject(err);
|
||||
});
|
||||
|
||||
/**
|
||||
* @param {any} info
|
||||
* @param {function(boolean, number, string): void} callback
|
||||
*/
|
||||
const wsVerifyClient = (info, callback) => {
|
||||
// When verifying the websockets connection, we no longer pre-emptively
|
||||
// check OAuth scopes and drop the connection if they're missing. We only
|
||||
// drop the connection if access without token is not allowed by environment
|
||||
// variables. OAuth scope checks are moved to the point of subscription
|
||||
// to a specific stream.
|
||||
|
||||
accountFromRequest(info.req).then(() => {
|
||||
callback(true, undefined, undefined);
|
||||
}).catch(err => {
|
||||
log.error(info.req.requestId, err.toString());
|
||||
callback(false, 401, 'Unauthorized');
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* @typedef SystemMessageHandlers
|
||||
* @property {function(): void} onKill
|
||||
|
@ -944,8 +1000,8 @@ const startServer = async () => {
|
|||
};
|
||||
|
||||
/**
|
||||
* @param {any} req
|
||||
* @param {any} ws
|
||||
* @param {http.IncomingMessage} req
|
||||
* @param {WebSocket} ws
|
||||
* @param {string[]} streamName
|
||||
* @returns {function(string, string): void}
|
||||
*/
|
||||
|
@ -955,7 +1011,9 @@ const startServer = async () => {
|
|||
return;
|
||||
}
|
||||
|
||||
ws.send(JSON.stringify({ stream: streamName, event, payload }), (err) => {
|
||||
const message = JSON.stringify({ stream: streamName, event, payload });
|
||||
|
||||
ws.send(message, (/** @type {Error} */ err) => {
|
||||
if (err) {
|
||||
log.error(req.requestId, `Failed to send to websocket: ${err}`);
|
||||
}
|
||||
|
@ -992,8 +1050,6 @@ const startServer = async () => {
|
|||
});
|
||||
});
|
||||
|
||||
const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
|
||||
|
||||
/**
|
||||
* @typedef StreamParams
|
||||
* @property {string} [tag]
|
||||
|
@ -1173,8 +1229,8 @@ const startServer = async () => {
|
|||
|
||||
/**
|
||||
* @typedef WebSocketSession
|
||||
* @property {any} socket
|
||||
* @property {any} request
|
||||
* @property {WebSocket} websocket
|
||||
* @property {http.IncomingMessage} request
|
||||
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
|
||||
*/
|
||||
|
||||
|
@ -1297,7 +1353,11 @@ const startServer = async () => {
|
|||
}
|
||||
};
|
||||
|
||||
wss.on('connection', (ws, req) => {
|
||||
/**
|
||||
* @param {WebSocket & { isAlive: boolean }} ws
|
||||
* @param {http.IncomingMessage} req
|
||||
*/
|
||||
function onConnection(ws, req) {
|
||||
// Note: url.parse could throw, which would terminate the connection, so we
|
||||
// increment the connected clients metric straight away when we establish
|
||||
// the connection, without waiting:
|
||||
|
@ -1375,7 +1435,9 @@ const startServer = async () => {
|
|||
if (location && location.query.stream) {
|
||||
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
wss.on('connection', onConnection);
|
||||
|
||||
setInterval(() => {
|
||||
wss.clients.forEach(ws => {
|
||||
|
|
Loading…
Reference in a new issue