diff --git a/include/crow/compression.h b/include/crow/compression.h index c33cb426e..b2549068a 100644 --- a/include/crow/compression.h +++ b/include/crow/compression.h @@ -1,6 +1,11 @@ #ifdef CROW_ENABLE_COMPRESSION #pragma once +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#endif +#include +#include #include #include @@ -93,6 +98,152 @@ namespace crow return inflated_string; } + + class Compressor + { + public: + Compressor(bool reset_before_compress, int window_bits, int level): + reset_before_compress_(reset_before_compress), window_bits_(window_bits) + { + stream_ = std::make_unique(); + stream_->zalloc = 0; + stream_->zfree = 0; + stream_->opaque = 0; + + ::deflateInit2(stream_.get(), + level, + Z_DEFLATED, + -window_bits_, + 8, + Z_DEFAULT_STRATEGY); + } + + ~Compressor() + { + ::deflateEnd(stream_.get()); + } + + bool needs_reset() const + { + return reset_before_compress_; + } + + int window_bits() const + { + return window_bits_; + } + + std::string compress(const std::string& src) + { + if (reset_before_compress_) + { + ::deflateReset(stream_.get()); + } + + stream_->next_in = reinterpret_cast(const_cast(src.c_str())); + stream_->avail_in = src.size(); + + constexpr const uint64_t bufferSize = 8192; + asio::streambuf buffer; + do + { + asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize); + + uint8_t* next_out = asio::buffer_cast(chunk); + + stream_->next_out = next_out; + stream_->avail_out = bufferSize; + + ::deflate(stream_.get(), reset_before_compress_ ? Z_FINISH : Z_SYNC_FLUSH); + + uint64_t outputSize = stream_->next_out - next_out; + buffer.commit(outputSize); + } while (stream_->avail_out == 0); + + uint64_t buffer_size = buffer.size(); + if (!reset_before_compress_) + { + buffer_size -= 4; + } + + return std::string(asio::buffer_cast(buffer.data()), buffer_size); + } + + private: + std::unique_ptr stream_; + + bool reset_before_compress_; + int window_bits_; + }; + + class Decompressor + { + public: + Decompressor(bool reset_before_decompress, int window_bits): + reset_before_decompress_(reset_before_decompress), window_bits_(window_bits) + { + stream_ = std::make_unique(); + stream_->zalloc = 0; + stream_->zfree = 0; + stream_->opaque = 0; + + ::inflateInit2(stream_.get(), -window_bits_); + } + + ~Decompressor() + { + inflateEnd(stream_.get()); + } + + bool needs_reset() const + { + return reset_before_decompress_; + } + + int window_bits() const + { + return window_bits_; + } + + std::string decompress(std::string src) + { + if (reset_before_decompress_) + { + inflateReset(stream_.get()); + } + + src.push_back('\x00'); + src.push_back('\x00'); + src.push_back('\xff'); + src.push_back('\xff'); + + stream_->next_in = reinterpret_cast(const_cast(src.c_str())); + stream_->avail_in = src.size(); + + constexpr const uint64_t bufferSize = 8192; + asio::streambuf buffer; + do + { + asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize); + + uint8_t* next_out = asio::buffer_cast(chunk); + + stream_->next_out = next_out; + stream_->avail_out = bufferSize; + + ::inflate(stream_.get(), reset_before_decompress_ ? Z_FINISH : Z_SYNC_FLUSH); + buffer.commit(stream_->next_out - next_out); + } while (stream_->avail_out == 0); + + return std::string(asio::buffer_cast(buffer.data()), buffer.size()); + } + + private: + std::unique_ptr stream_; + + bool reset_before_decompress_; + int window_bits_; + }; } // namespace compression } // namespace crow diff --git a/include/crow/websocket.h b/include/crow/websocket.h index c7d8f0638..885840c18 100644 --- a/include/crow/websocket.h +++ b/include/crow/websocket.h @@ -1,10 +1,12 @@ #pragma once #include +#include #include "crow/logging.h" #include "crow/socket_adaptors.h" #include "crow/http_request.h" #include "crow/TinySHA1.hpp" #include "crow/utility.h" +#include "crow/compression.h" namespace crow { @@ -107,6 +109,17 @@ namespace crow userdata(ud); } +#ifdef CROW_ENABLE_COMPRESSION + std::string extensions_header = req.get_header_value("Sec-WebSocket-Extensions"); + if (extensions_header.find("permessage-deflate") != std::string::npos) + { + const bool reset_compressor = extensions_header.find("server_no_context_takeover") != std::string::npos; + compressor_ = std::make_unique(reset_compressor, compression::DEFLATE, Z_BEST_COMPRESSION); + const bool reset_decompressor = extensions_header.find("client_no_context_takeover") != std::string::npos; + decompressor_ = std::make_unique(reset_decompressor, compression::DEFLATE); + } +#endif + // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== // Sec-WebSocket-Version: 13 std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -186,13 +199,29 @@ namespace crow /// Send a binary encoded message. void send_binary(std::string msg) override { - send_data(0x2, std::move(msg)); + int opcode = 0x2; +#ifdef CROW_ENABLE_COMPRESSION + if (compressor_) + { + opcode += 0x40; + msg = compressor_->compress(msg); + } +#endif + send_data(opcode, std::move(msg)); } /// Send a plaintext message. void send_text(std::string msg) override { - send_data(0x1, std::move(msg)); + int opcode = 0x1; +#ifdef CROW_ENABLE_COMPRESSION + if (compressor_) + { + opcode += 0x40; + msg = compressor_->compress(msg); + } +#endif + send_data(opcode, std::move(msg)); } /// Send a close signal. @@ -265,6 +294,19 @@ namespace crow write_buffers_.emplace_back(header); write_buffers_.emplace_back(std::move(hello)); write_buffers_.emplace_back(crlf); +#ifdef CROW_ENABLE_COMPRESSION + if (compressor_ && decompressor_) + { + write_buffers_.emplace_back( + "Sec-WebSocket-Extensions: permessage-deflate" + "; server_max_window_bits=" + + std::to_string(compressor_->window_bits()) + + "; client_max_window_bits=" + std::to_string(decompressor_->window_bits()) + + (compressor_->needs_reset() ? "; server_no_context_takeover" : "") + + (decompressor_->needs_reset() ? "; client_no_context_takeover" : "")); + write_buffers_.emplace_back(crlf); + } +#endif write_buffers_.emplace_back(crlf); do_write(); if (open_handler_) @@ -528,6 +570,12 @@ namespace crow return mini_header_ & 0x8000; } + /// Check if payload is compressed + bool is_compressed() + { + return mini_header_ & 0x4000; + } + /// Extract the opcode from the header. int opcode() { @@ -555,7 +603,11 @@ namespace crow if (is_FIN()) { if (message_handler_) +#ifdef CROW_ENABLE_COMPRESSION + message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_); +#else message_handler_(*this, message_, is_binary_); +#endif message_.clear(); } } @@ -567,7 +619,11 @@ namespace crow if (is_FIN()) { if (message_handler_) +#ifdef CROW_ENABLE_COMPRESSION + message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_); +#else message_handler_(*this, message_, is_binary_); +#endif message_.clear(); } } @@ -579,7 +635,11 @@ namespace crow if (is_FIN()) { if (message_handler_) +#ifdef CROW_ENABLE_COMPRESSION + message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_); +#else message_handler_(*this, message_, is_binary_); +#endif message_.clear(); } } @@ -734,6 +794,10 @@ namespace crow std::shared_ptr anchor_ = std::make_shared(); // Value is just for placeholding +#ifdef CROW_ENABLE_COMPRESSION + std::unique_ptr compressor_; + std::unique_ptr decompressor_; +#endif std::function open_handler_; std::function message_handler_; std::function close_handler_;