From 5ad52377a165abd05512ff0ecbb4ba5ff9a243ab Mon Sep 17 00:00:00 2001 From: gary-connor256 Date: Mon, 28 Apr 2025 11:10:57 +0900 Subject: [PATCH 1/2] added: websocket module --- ws.lua | 395 +++++++++++++++++++++++++++++++++++++++++++++++++ ws/packets.lua | 176 ++++++++++++++++++++++ ws/parser.lua | 41 +++++ wshandler.lua | 78 ++++++++++ 4 files changed, 690 insertions(+) create mode 100644 ws.lua create mode 100644 ws/packets.lua create mode 100644 ws/parser.lua create mode 100644 wshandler.lua diff --git a/ws.lua b/ws.lua new file mode 100644 index 0000000..97198c7 --- /dev/null +++ b/ws.lua @@ -0,0 +1,395 @@ +require "strmproxy.utils.stringUtils" +require "strmproxy.utils.pureluapack" +local event=require "strmproxy.utils.event" +local logger=require "strmproxy.utils.compatibleLog" +local format = string.format +local bit = require "bit" +local ffi = require "ffi" + +local byte = string.byte +local band = bit.band +local bor = bit.bor +local bxor = bit.bxor +local lshift = bit.lshift +local rshift = bit.rshift +local ffi_string = ffi.string + + +local _M = {} +_M._PROTOCOL ='ws' + +local protocolPacket=require ("strmproxy.".. _M._PROTOCOL .. ".packets") + +function _M.new(self) + local o = setmetatable({},{__index=self}) + o.c2p_stage = "INIT" + o.p2s_stage = "INIT" + + o.HandshakeRequestEvent=event:newReturnEvent(o,"HandshakeRequestEvent") + o.HandshakeResponseEvent=event:newReturnEvent(o,"HandshakeResponseEvent") + o.FrameEvent=event:newReturnEvent(o,"FrameEvent") + o.ctx={} + + local parser=require ("strmproxy.".. _M._PROTOCOL ..".parser"):new() + + o.C2PParser = parser.C2PParser + o.C2PParser.events.TextEvent:addHandler(o, self.OnUpTextEvent) + o.C2PParser.events.BinaryEvent:addHandler(o, self.OnUpBinaryEvent) + o.C2PParser.events.CloseEvent:addHandler(o, self.OnUpCloseEvent) + o.C2PParser.events.PingEvent:addHandler(o, self.OnUpPingEvent) + o.C2PParser.events.PongEvent:addHandler(o, self.OnUpPongEvent) + + o.S2PParser = parser.S2PParser + o.S2PParser.events.TextEvent:addHandler(o, self.OnDownTextEvent) + o.S2PParser.events.BinaryEvent:addHandler(o, self.OnDownBinaryEvent) + o.S2PParser.events.CloseEvent:addHandler(o, self.OnDownCloseEvent) + o.S2PParser.events.PingEvent:addHandler(o, self.OnDownPingEvent) + o.S2PParser.events.PongEvent:addHandler(o, self.OnDownPongEvent) + + return o +end + + +local function parse_handshake(headers) + local handshake = {} + + -- Split headers by newlines + for line in headers:gmatch("[^\r\n]+") do + local key, value = line:match("^(%S+):%s*(.+)$") + if key and value then + handshake[key] = value + end + end + + return handshake +end + +local function readHandshakeRequest(self, sock) + local req_line, err + local headers = {} + + -- Read the HTTP request from the client + req_line, err = sock:receive("*l") + if err then return nil,nil,err end + + -- Read headers until an empty line is encountered + while true do + local line = sock:receive("*l") + if line == "" then break end + headers[#headers + 1] = line + end + + -- Combine headers into a single string + local headers_string = table.concat(headers, "\r\n") + + -- Parse the handshake headers + local handshake = parse_handshake(headers_string) + + -- Check for required headers + if handshake["Upgrade"] and handshake["Connection"]:lower():find("upgrade") then + handshake["Host"] = self.channel.upstream.ip + if self.channel.upstream.ssl then + handshake["Host"]=handshake["Host"].. + (self.channel.upstream.port ~= 443 and ":"..self.channel.upstream.port or "") + else + handshake["Host"]=handshake["Host"].. + (self.channel.upstream.port ~= 80 and ":"..self.channel.upstream.port or "") + end + + handshake["Sec-WebSocket-Extensions"] = nil + + headers_string = "" + for key, value in pairs(handshake) do + headers_string=headers_string..key..": "..value.."\r\n" + end + else + -- Handle invalid handshake + err ="HTTP/1.1 400 Bad Request\r\n"..headers_string + end + + handshake["req_line"] = req_line + local allBytes = req_line.."\r\n"..headers_string.."\r\n" + + return allBytes, handshake, err +end + +local function readHandshakeResponse(self, sock) + local resp_line, err + local headers = {} + + -- Read the HTTP response from the server + resp_line, err = sock:receive("*l") + if err then return nil,nil,err end + + -- Read headers until an empty line is encountered + while true do + local line = sock:receive("*l") + if line == "" then break end + headers[#headers + 1] = line + end + + -- Combine headers into a single string + local headers_string = table.concat(headers, "\r\n") + + -- Parse the handshake headers + local handshake = parse_handshake(headers_string) + + -- Check for required headers + if handshake["Upgrade"] and handshake["Connection"]:lower():find("upgrade") then + else + -- Handle invalid handshake + err ="HTTP/1.1 400 Bad Request\r\n"..headers_string + end + + handshake["resp_line"] = resp_line + local allBytes = resp_line.."\r\n"..headers_string.."\r\n\r\n" + + return allBytes, handshake, err +end + +---------------parser event handlers---------------------- +function _M:OnUpTextEvent(source, packet, up) + self.FrameEvent:trigger({type="text", payload=packet.payload, up="up"}, self.ctx) +end + +function _M:OnUpBinaryEvent(source, packet, up) + self.FrameEvent:trigger({type="binary", payload=packet.payload, up="up"}, self.ctx) +end + +function _M:OnUpCloseEvent(source, packet, up) + self.FrameEvent:trigger({type="close", payload=packet.payload, up="up"}, self.ctx) +end + +function _M:OnUpPingEvent(source, packet, up) + self.FrameEvent:trigger({type="ping", payload=packet.payload, up="up"}, self.ctx) +end + +function _M:OnUpPongEvent(source, packet, up) + self.FrameEvent:trigger({type="pong", payload=packet.payload, up="up"}, self.ctx) +end + +function _M:OnDownTextEvent(source, packet, up) + self.FrameEvent:trigger({type="text", payload=packet.payload, up="down"}, self.ctx) +end + +function _M:OnDownBinaryEvent(source, packet, up) + self.FrameEvent:trigger({type="binary", payload=packet.payload, up="down"}, self.ctx) +end + +function _M:OnDownCloseEvent(source, packet, up) + self.FrameEvent:trigger({type="close", payload=packet.payload, up="down"}, self.ctx) +end + +function _M:OnDownPingEvent(source, packet, up) + self.FrameEvent:trigger({type="ping", payload=packet.payload, up="down"}, self.ctx) +end + +function _M:OnDownPongEvent(source, packet, up) + self.FrameEvent:trigger({type="pong", payload=packet.payload, up="down"}, self.ctx) +end + +---------------receive and parse packet---------------------- +local function recv(self, readMethod, max_payload_len, force_masking, up) + local allBytes + + local data, err = readMethod(self.channel, 2) + if not data then + return nil, "failed to receive the first 2 bytes: " .. err + end + + allBytes = data + + local fst, snd = byte(data, 1, 2) + local fin = band(fst, 0x80) ~= 0 + -- print("fin: ", fin) + + -- if band(fst, 0x70) ~= 0 then + -- return nil, "bad RSV1, RSV2, or RSV3 bits" + -- end + + local opcode = band(fst, 0x0f) + -- print("opcode: ", tohex(opcode)) + + if opcode >= 0x3 and opcode <= 0x7 then + return nil, "reserved non-control frames" + end + + if opcode >= 0xb and opcode <= 0xf then + return nil, "reserved control frames" + end + + local mask = band(snd, 0x80) ~= 0 + + if force_masking and not mask then + return nil, "frame unmasked" + end + + local payload_len = band(snd, 0x7f) + -- print("payload len: ", payload_len) + + if payload_len == 126 then + local data, err = readMethod(self.channel, 2) + if not data then + return nil, "failed to receive the 2 byte payload length: " + .. (err or "unknown") + end + + allBytes = allBytes..data + + payload_len = bor(lshift(byte(data, 1), 8), byte(data, 2)) + + elseif payload_len == 127 then + local data, err = readMethod(self.channel, 8) + if not data then + return nil, "failed to receive the 8 byte payload length: " + .. (err or "unknown") + end + + allBytes = allBytes..data + + if byte(data, 1) ~= 0 + or byte(data, 2) ~= 0 + or byte(data, 3) ~= 0 + or byte(data, 4) ~= 0 + then + return nil, "payload len too large" + end + + local fifth = byte(data, 5) + if band(fifth, 0x80) ~= 0 then + return nil, "payload len too large" + end + + payload_len = bor(lshift(fifth, 24), + lshift(byte(data, 6), 16), + lshift(byte(data, 7), 8), + byte(data, 8)) + end + + if band(opcode, 0x8) ~= 0 then + -- being a control frame + if payload_len > 125 then + return nil, "too long payload for control frame" + end + + if not fin then + return nil, "fragmented control frame" + end + end + + -- print("payload len: ", payload_len, ", max payload len: ", + -- max_payload_len) + + if payload_len > max_payload_len then + return nil, "exceeding max payload len" + end + + local rest + if mask then + rest = payload_len + 4 + else + rest = payload_len + end + -- print("rest: ", rest) + + local data + if rest > 0 then + data, err = readMethod(self.channel, rest) + if not data then + return nil, "failed to read masking-len and payload: " + .. (err or "unknown") + end + + allBytes = allBytes..data + else + data = "" + end + + -- print("received rest") + + if opcode == 0x8 then + -- being a close frame + if payload_len > 0 then + if payload_len < 2 then + return nil, "close frame with a body must carry a 2-byte" + .. " status code" + end + end + end + + local parser = up and self.C2PParser or self.S2PParser + local packet, err = parser:parse(allBytes, nil, opcode, data) + return packet, err +end + +function _M.processUpRequest(self) + + if (self.c2p_stage == "INIT") then + + logger.dbg("Websocket>: readHandshakeRequest()") + local allBytes, handshake, err = readHandshakeRequest(self, self.channel.c2pSock) + if err then + logger.err("Websocket>: -- Failed to read handshake request: ", err) + return nil,err + end + self.c2p_stage = "HANDSHAKE" + self.HandshakeRequestEvent:trigger(handshake, self.ctx) + return allBytes + + else + + local packet, err = recv(self, self.channel.c2pRead, 65536, false, true) + if err then return nil,err end + return packet.allBytes + + -- TEST: we can block/replace data payload + -- TODO: check if this is the correct way to do this + -- local c2pPacket = self.C2PParser.parserList[packet.code].parser:new({ + -- fin=packet.fin, + -- mask=packet.mask, + -- payload=packet.payload ... "[777]" + -- }):pack() + -- self:sendUp(c2pPacket.allBytes) + -- return + + end + +end + + + +function _M.processDownRequest(self) + + if (self.p2s_stage == "INIT") then + + logger.dbg("Websocket>: readHandshakeResponse()") + local allBytes, handshake, err = readHandshakeResponse(self, self.channel.p2sSock) + if err then + logger.err("Websocket>: -- Failed to read handshake response: ", err) + return nil,err + end + self.p2s_stage = "OK" + self.c2p_stage = "OK" + self.HandshakeResponseEvent:trigger(handshake, self.ctx) + return allBytes + + else + + local packet, err = recv(self, self.channel.p2sRead, 65536, false, false) + if err then return nil,err end + return packet.allBytes + + -- TEST: we can block/replace data payload + -- TODO: check if this is the correct way to do this + -- local s2pPacket = self.S2PParser.parserList[packet.code].parser:new({ + -- fin=packet.fin, + -- mask=packet.mask, + -- payload=packet.payload .. "[888]" + -- }):pack() + -- self:sendDown(s2pPacket.allBytes) + -- return + end +end + + +return _M \ No newline at end of file diff --git a/ws/packets.lua b/ws/packets.lua new file mode 100644 index 0000000..7530074 --- /dev/null +++ b/ws/packets.lua @@ -0,0 +1,176 @@ +require "strmproxy.utils.stringUtils" +require "strmproxy.utils.pureluapack" +local tableUtils=require "strmproxy.utils.tableUtils" +local extends=tableUtils.extends +local orderTable=tableUtils.OrderedTable +local bit = require "bit" +local ffi = require "ffi" +local zlib = require("zlib") + +local byte = string.byte +local char = string.char +local sub = string.sub +local band = bit.band +local bor = bit.bor +local bxor = bit.bxor +local lshift = bit.lshift +local rshift = bit.rshift +local rand = math.random +local ffi_new = ffi.new +local ffi_string = ffi.string + +local _M={} + +--Packet type defines, only the type that have been implemented are listed +_M.PktType={ + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xA, +} + +local str_buf_size = 4096 +local str_buf +local c_buf_type = ffi.typeof("char[?]") + +local function get_string_buf(size) + if size > str_buf_size then + return ffi_new(c_buf_type, size) + end + if not str_buf then + str_buf = ffi_new(c_buf_type, str_buf_size) + end + + return str_buf +end + +_M.Base = { + parse = function(self, allBytes, pos, data) + self.allBytes=allBytes + self:parsePayload(allBytes, pos, data) + return self + end, + + parsePayload = function(self, allBytes, pos, data) + local fst, snd = byte(allBytes, 1, 2) + self.opcode = band(fst, 0x0f) + self.fin = band(fst, 0x80) ~= 0 + self.mask = band(snd, 0x80) ~= 0 + + if self.mask then + local payload_len = #data - 4 + -- TODO string.buffer optimizations + local bytes = get_string_buf(payload_len) + for i = 1, payload_len do + bytes[i - 1] = bxor(byte(data, 4 + i), + byte(data, (i - 1) % 4 + 1)) + end + self.payload = ffi_string(bytes, payload_len) + else + self.payload = data + end + + return self + end, + + pack = function (self) + self.allBytes = self:packPayload() + return self + end, + + packPayload = function (self) + local fst + if self.fin then + fst = bor(0x80, self.code) + else + fst = self.code + end + + local payload_len = #self.payload + local snd, extra_len_bytes + if payload_len <= 125 then + snd = payload_len + extra_len_bytes = "" + + elseif payload_len <= 65535 then + snd = 126 + extra_len_bytes = char(band(rshift(payload_len, 8), 0xff), + band(payload_len, 0xff)) + + else + if band(payload_len, 0x7fffffff) < payload_len then + return nil, "payload too big" + end + + snd = 127 + -- XXX we only support 31-bit length here + extra_len_bytes = char(0, 0, 0, 0, band(rshift(payload_len, 24), 0xff), + band(rshift(payload_len, 16), 0xff), + band(rshift(payload_len, 8), 0xff), + band(payload_len, 0xff)) + end + + local masking_key + if self.mask then + -- set the mask bit + snd = bor(snd, 0x80) + local key = rand(0xffffffff) + masking_key = char(band(rshift(key, 24), 0xff), + band(rshift(key, 16), 0xff), + band(rshift(key, 8), 0xff), + band(key, 0xff)) + + -- TODO string.buffer optimizations + local bytes = get_string_buf(payload_len) + for i = 1, payload_len do + bytes[i - 1] = bxor(byte(self.payload, i), + byte(masking_key, (i - 1) % 4 + 1)) + end + self.payload = ffi_string(bytes, payload_len) + + else + masking_key = "" + end + + return char(fst, snd) .. extra_len_bytes .. masking_key .. self.payload + end, + + new=function(self,o) + local o=o or {} + return orderTable.new(self,o) + end +} + +_M.Continuation={ + code=_M.PktType.Continuation, +} +extends(_M.Continuation,_M.Base) + +_M.Text={ + code=_M.PktType.Text, +} +extends(_M.Text,_M.Base) + +_M.Binary={ + code=_M.PktType.Binary, +} +extends(_M.Binary,_M.Base) + +_M.Close={ + code=_M.PktType.Close, +} +extends(_M.Close,_M.Base) + +_M.Ping={ + code=_M.PktType.Ping, +} +extends(_M.Ping,_M.Base) + +_M.Pong={ + code=_M.PktType.Pong, +} +extends(_M.Pong,_M.Base) + +return _M \ No newline at end of file diff --git a/ws/parser.lua b/ws/parser.lua new file mode 100644 index 0000000..a77338f --- /dev/null +++ b/ws/parser.lua @@ -0,0 +1,41 @@ +local bit = require "bit" + +local P=require "strmproxy.ws.packets" +local parser=require("strmproxy.parser") + +local byte = string.byte +local band = bit.band + +local _M={} + +-- only define the parsers that are needed to parse in the Websocket protocol +local conf={ + {key=P.PktType.Continuation, parserName="Continuation", parser=P.Continuation, eventName="ContinuationEvent"}, + {key=P.PktType.Text, parserName="Text", parser=P.Text, eventName="TextEvent"}, + {key=P.PktType.Binary, parserName="Binary", parser=P.Binary, eventName="BinaryEvent"}, + {key=P.PktType.Close, parserName="Close", parser=P.Close, eventName="CloseEvent"}, + {key=P.PktType.Ping, parserName="Ping", parser=P.Ping, eventName="PingEvent"}, + {key=P.PktType.Pong, parserName="Pong", parser=P.Pong, eventName="PongEvent"}, +} + +local KeyG=function(allBytes) + return band(byte(allBytes, 1), 0x0f) +end + +function _M:new() + local o= setmetatable({},{__index=self}) + local C2PParser=parser:new() + C2PParser.keyGenerator=keyG + C2PParser:registerMulti(conf) + C2PParser:registerDefaultParser(P.Base) + o.C2PParser=C2PParser + + local S2PParser=parser:new() + S2PParser.keyGenerator=keyG + S2PParser:registerMulti(conf) + S2PParser:registerDefaultParser(P.Base) + o.S2PParser=S2PParser + return o +end + +return _M \ No newline at end of file diff --git a/wshandler.lua b/wshandler.lua new file mode 100644 index 0000000..32059d0 --- /dev/null +++ b/wshandler.lua @@ -0,0 +1,78 @@ +local logger = require "strmproxy.utils.compatibleLog" +local sockLogger = require "resty.logger.socket" +local format = string.format + +local _M = {} +_M._PROTOCOL = "WS" + +if not sockLogger.initted() then + local ok, err = sockLogger.init { + -- logger server address + host = '127.0.0.1', + port = 12080, + flush_limit = 10, + drop_limit = 567800, + } + if not ok then + logger.err("failed to initialize the logger: ", err) + end +else + logger.err("logger module already initialized") +end + +local function wsLog(data) + if sockLogger then + local bytes, err = sockLogger.log(data) + if err then + logger.err("failed to log reply: ", err) + end + else + logger.dbg( data) + end +end + +local function OnConnect(context, source, session) + if session then + local log = format("[".._M._PROTOCOL .. "] connected from %s:%s to %s:%s\r\n", session.clientIP, session.clientPort, session.srvIP, session.srvPort) + wsLog(log) + else + logger.dbg("session is nil") + end +end + +local function OnHandshakeRequestEvent(context, source, headers) + local log = format('['.._M._PROTOCOL .. "] Handshake Request: %s %s %s\r\n", + headers["req_line"], headers["Connection"], headers["Upgrade"]) + logger.dbg("[" .. _M._PROTOCOL .. " ] ", log) + wsLog(log) +end + +local function OnHandshakeResponseEvent(context, source, headers) + local log = format('['.._M._PROTOCOL .. "] Handshake Response: %s %s %s\r\n", + headers["resp_line"], headers["Connection"], headers["Upgrade"]) + logger.dbg("[" .. _M._PROTOCOL .. " ] ", log) + wsLog(log) +end + +local function OnFrameEvent(context, source, packet) + local log + if (packet.up == "up") then + log = format('['.._M._PROTOCOL .. "] %s:%s sent \t\t[%s] %s\r\n", + source.ctx.clientIP, source.ctx.clientPort, + packet.type, packet.payload) + else + log = format('['.._M._PROTOCOL .. "] received from %s:%s \t[%s] %s\r\n", + source.ctx.srvIP, source.ctx.srvPort, + packet.type, packet.payload) + end + + logger.dbg("[" .. _M._PROTOCOL .. " ] ", log) + wsLog(log) +end + +_M.OnConnect = OnConnect +_M.OnHandshakeRequestEvent = OnHandshakeRequestEvent +_M.OnHandshakeResponseEvent = OnHandshakeResponseEvent +_M.OnFrameEvent = OnFrameEvent + +return _M \ No newline at end of file From cd84aa888bed26d71bd53b724041c1a734df29ba Mon Sep 17 00:00:00 2001 From: novaecho256 Date: Thu, 15 May 2025 17:36:05 +0900 Subject: [PATCH 2/2] update websocket module --- channel.lua | 540 ++++++++++++++++++++++++++++++++++--------------- ws.lua | 84 +++----- ws/packets.lua | 4 +- wshandler.lua | 9 +- 4 files changed, 414 insertions(+), 223 deletions(-) diff --git a/channel.lua b/channel.lua index b15ea2f..44b5f6d 100644 --- a/channel.lua +++ b/channel.lua @@ -1,163 +1,377 @@ -local sub = string.sub local byte = string.byte local format = string.format local tcp = ngx.socket.tcp local setmetatable = setmetatable local spawn = ngx.thread.spawn local wait = ngx.thread.wait local logger = require "suproxy.utils.compatibleLog" local ses= require "suproxy.session.session" local cjson=require "cjson" -local event=require "suproxy.utils.event" local balancer=require "suproxy.balancer.balancer" local _M={} - -_M._VERSION = '0.01' - - -function _M:new(upstreams,processor,options) local o={} options =options or {} options.c2pConnTimeout=options.c2pConnTimeout or 10000 options.c2pSendTimeout=options.c2pSendTimeout or 10000 options.c2pReadTimeout=options.c2pReadTimeout or 3600000 options.p2sConnTimeout=options.p2sConnTimeout or 10000 options.p2sSendTimeout=options.p2sSendTimeout or 10000 options.p2sReadTimeout=options.p2sReadTimeout or 3600000 - local c2pSock, err = ngx.req.socket() - if not c2pSock then - return nil, err - end - c2pSock:settimeouts(options.c2pConnTimeout , options.c2pSendTimeout , options.c2pReadTimeout) - local standalone=false - if(not upstreams) then - logger.log(logger.ERR, format("[SuProxy] no upstream specified, Proxy will run in standalone mode")) - standalone=true - end - local p2sSock=nil - if(not standalone) then - p2sSock, err = tcp() - if not p2sSock then - return nil, err - end - p2sSock:settimeouts(options.p2sConnTimeout , options.p2sSendTimeout , options.p2sReadTimeout ) - end - --add default receive-then-forward processor - if(not processor and not standalone) then - processor={} - processor.processUpRequest=function(self) - local data, err, partial =self.channel:c2pRead(1024*10) --real error happend or timeout if not data and not partial and err then return nil,err end - if(data and not err) then - return data - else - return partial - end - end - processor.processDownRequest=function(self) - local data, err, partial = self.channel:p2sRead(1024*10) --real error happend or timeout if not data and not partial and err then return nil,err end - if(data and not err) then - return data - else - return partial - end - end - end - --add default echo processor if proxy in standalone mode - if(not processor and standalone) then - processor={} - processor.processUpRequest=function(self) - local data, err, partial =self.channel:c2pRead(1024*10) - --real error happend or timeout if not data and not partial and err then return nil,err end - local echodata="" - if(data and not err) then - echodata=data - else - echodata=partial - end - logger.log(logger.INFO,echodata) - local _,err=self.channel:c2pSend(echodata) - logger.log(logger.ERR,partial) - end - end - local upForwarder=function(self,data) - if data then return self.channel:p2sSend(data) end - end - local downForwarder=function(self,data) - if data then return self.channel:c2pSend(data) end - end - --add default upforwarder - processor.sendUp=processor.sendUp or upForwarder - --add default downforwarder - processor.sendDown=processor.sendDown or downForwarder - processor.ctx=processor.ctx or {} local sessionInvalidHandler=function (self,session) logger.log(logger.DEBUG,"session closed") self:shutdown() end --set default session invalid handler processor.sessionInvalid=processor.sessionInvalid or sessionInvalidHandler --set AuthSuccessEvent handler if processor.AuthSuccessEvent then processor.AuthSuccessEvent:addHandler(o,function(self,source,username) if self.session and username then self.session.uid=username end end) end --update ctx info to session if processor.ContextUpdateEvent then processor.ContextUpdateEvent:addHandler(o,function(self,source,ctx) if ctx and self.session then self.session.ctx=ctx end end) end o.p2sSock=p2sSock o.c2pSock=c2pSock o.processor=processor o.balancer=upstreams.getBest and upstreams or balancer:new(upstreams) o.standalone=standalone o.OnConnectEvent=event:new(o,"OnConnectEvent") o.sessionMan=options.sessionMan or ses:newDoNothing() setmetatable(o, { __index = self }) processor.channel=o return o -end - local function _cleanup(self) - logger.log(logger.DEBUG, format("[SuProxy] clean up executed")) - -- make sure buffers are clean - ngx.flush(true) - local p2sSock = self.p2sSock - local c2pSock = self.c2pSock - if p2sSock ~= nil then - if p2sSock.shutdown then - p2sSock:shutdown("send") - end - if p2sSock.close ~= nil then - local ok, err = p2sSock:setkeepalive() - if not ok then - -- - end - end - end - - if c2pSock ~= nil then - if c2pSock.shutdown then - c2pSock:shutdown("send") - end - if c2pSock.close ~= nil then - local ok, err = c2pSock:close() - if not ok then - -- - end - end - end - -end - local function _upl(self) - -- proxy client request to server local upstream=self.upstream - local buf, err, partial local session,err=ses:new(self.processor._PROTOCAL,self.sessionMan) if err then logger.log(logger.ERR, format("[SuProxy] start session fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) return end self.processor.ctx.clientIP=ngx.var.remote_addr self.processor.ctx.clientPort=ngx.var.remote_port self.processor.ctx.srvIP=upstream.ip self.processor.ctx.srvPort=upstream.port self.processor.ctx.srvID=upstream.id self.processor.ctx.srvGID=upstream.gid self.processor.ctx.connTime=ngx.time() session.ctx=self.processor.ctx self.session=session self.OnConnectEvent:trigger({clientIP=session.ctx.clientIP,clientPort=session.ctx.clientPort,srvIP=session.ctx.srvIP,srvPort=session.ctx.srvPort}) - while true do --todo: sessionMan should notify session change if not self.session:valid(self.session) then self.processor:sessionInvalid(self.session) else self.session.uptime=ngx.time() end logger.log(logger.DEBUG,"client --> proxy start process") - buf, err, partial = self.processor:processUpRequest(self.standalone) - if err then - logger.log(logger.ERR, format("[SuProxy] processUpRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) - break - end - --if in standalone mode, don't forward - if not self.standalone and buf then - local _, err = self.processor:sendUp(buf) - if err then - logger.log(logger.ERR, format("[SuProxy] forward to upstream fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) - break - end - end - end self:shutdown(upstream) -end - local function _dwn(self) local upstream=self.upstream - -- proxy response to client - local buf, err, partial - while true do logger.log(logger.DEBUG,"server --> proxy start process") - buf, err, partial = self.processor:processDownRequest(self.standalone) - if err then - logger.log(logger.ERR, format("[SuProxy] processDownRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) - break - end - if buf then - local _, err = self.processor:sendDown(buf) - if err then - logger.log(logger.ERR, format("[SuProxy] forward to downstream fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) - break - end - end - end self:shutdown(upstream) -end function _M:c2pRead(length) local bytes,err,partial= self.c2pSock:receive(length) logger.logWithTitle(logger.DEBUG,"c2pRead",(bytes and bytes:hex16F() or "")) return bytes,err,partial end function _M:p2sRead(length) local bytes,err,partial= self.p2sSock:receive(length) logger.logWithTitle(logger.DEBUG,"p2sRead",(bytes and bytes:hex16F() or "")) return bytes,err,partial end function _M:c2pSend(bytes) logger.logWithTitle(logger.DEBUG,"c2pSend",(bytes and bytes:hex16F() or "")) return self.c2pSock:send(bytes) end function _M:p2sSend(bytes) logger.logWithTitle(logger.DEBUG,"p2sSend",(bytes and bytes:hex16F() or "")) return self.p2sSock:send(bytes) end -function _M:run() --this while is to ensure _cleanup will always be executed - while true do local upstream - if(not self.standalone) then while true do upstream=self.balancer:getBest() if not upstream then logger.log(logger.ERR, format("[SuProxy] failed to get avaliable upstream")) break end - local ok, err = self.p2sSock:connect(upstream.ip, upstream.port) - if not ok then - logger.log(logger.ERR, format("[SuProxy] failed to connect to proxy upstream: %s:%s, err:%s", upstream.ip, upstream.port, err)) - self.balancer:blame(upstream) - else logger.log(logger.INFO, format("[SuProxy] connect to proxy upstream: %s:%s", upstream.ip, upstream.port)) self.upstream=upstream break end end - end if not self.standalone and not upstream then break end --_singThreadRun(self) - local co_upl = spawn(_upl,self) - if(not self.standalone) then - local co_dwn = spawn(_dwn,self) - wait(co_dwn) - end - wait(co_upl) - break - end - _cleanup(self) -end function _M:shutdown() if self.session then --self.processor:sessionInvalid(self.session) local err=self.session:kill(self.session) if err then logger.log(logger.ERR, format("[SuProxy] kill session fail: %s:%s, err:%s", self.upstream.ip, self.upstream.port, err)) end end _cleanup(self) end - -return _M +local sub = string.sub +local byte = string.byte +local format = string.format +local tcp = ngx.socket.tcp +local setmetatable = setmetatable +local spawn = ngx.thread.spawn +local wait = ngx.thread.wait +local logger = require "strmproxy.utils.compatibleLog" +local ses = require "strmproxy.session.session" +local cjson = require "cjson" +local event = require "strmproxy.utils.event" +local balancer = require "strmproxy.balancer.balancer" + +local _M = {} + +_M._VERSION = '0.01' + +function _M:new(upstreams, processor, options) + local o = {} + options = options or {} + options.c2pConnTimeout = options.c2pConnTimeout or 10000 + options.c2pSendTimeout = options.c2pSendTimeout or 10000 + options.c2pReadTimeout = options.c2pReadTimeout or 3600000 + options.p2sConnTimeout = options.p2sConnTimeout or 10000 + options.p2sSendTimeout = options.p2sSendTimeout or 10000 + options.p2sReadTimeout = options.p2sReadTimeout or 3600000 + if ngx.var.sockettype == "udp" then + options.udp = true + end + local c2pSock, err = ngx.req.socket(options.raw) + if not c2pSock then + return nil, err + end + if not options.udp then + c2pSock:settimeouts(options.c2pConnTimeout, options.c2pSendTimeout, options.c2pReadTimeout) + end + local standalone = false + if (not upstreams) then + logger.err(format(">[new] no upstream specified, Proxy will run in standalone mode")) + standalone = true + end + local p2sSock = nil + if (not standalone) then + if not options.udp then + p2sSock, err = tcp() + else + p2sSock, err = ngx.socket.udp() + end + if not p2sSock then + return nil, err + end + if not options.udp then + p2sSock:settimeouts(options.p2sConnTimeout, options.p2sSendTimeout, options.p2sReadTimeout) + end + end + --add default receive-then-forward processor + if (not processor and not standalone) then + processor = {} + processor.processUpRequest = function(self) + local data, err, partial = self.channel:c2pRead(1024 * 10) + --real error happend or timeout + if not data and not partial and err then return nil, err end + if (data and not err) then + return data + else + return partial + end + end + processor.processDownRequest = function(self) + local data, err, partial = self.channel:p2sRead(1024 * 10) + --real error happend or timeout + if not data and not partial and err then return nil, err end + if (data and not err) then + return data + else + return partial + end + end + end + --add default echo processor if proxy in standalone mode + if (not processor and standalone) then + logger.err(format(">[new] not processor and standalone")) + processor = {} + processor.processUpRequest = function(self) + local data, err, partial = self.channel:c2pRead(1024 * 10) + --real error happend or timeout + if not data and not partial and err then return nil, err end + local echodata = "" + if (data and not err) then + echodata = data + else + echodata = partial + end + logger.inf( echodata) + local _, err = self.channel:c2pSend(echodata) + logger.err( partial) + end + end + + local upForwarder = function(self, data) + if data then return self.channel:p2sSend(data) end + end + + local downForwarder = function(self, data) + if data then return self.channel:c2pSend(data) end + end + + --add default upforwarder + processor.sendUp = processor.sendUp or upForwarder + --add default downforwarder + processor.sendDown = processor.sendDown or downForwarder + + processor.ctx = processor.ctx or {} + + local sessionInvalidHandler = function(self, session) + logger.dbg(">[new] session closed") + self:shutdown() + end + --set default session invalid handler + processor.sessionInvalid = processor.sessionInvalid or sessionInvalidHandler + --set AuthSuccessEvent handler + if processor.AuthSuccessEvent then + processor.AuthSuccessEvent:addHandler(o, function(self, source, username) + if self.session and username then self.session.uid = username end + end) + end + --update ctx info to session + if processor.ContextUpdateEvent then + processor.ContextUpdateEvent:addHandler(o, function(self, source, ctx) + if ctx and self.session then + self.session.ctx = ctx + end + end) + end + o.p2sSock = p2sSock + o.c2pSock = c2pSock + o.processor = processor + o.balancer = upstreams.getBest and upstreams or balancer:new(upstreams) + o.standalone = standalone + o.OnConnectEvent = event:new(o, "OnConnectEvent") + o.sessionMan = options.sessionMan or ses:newDoNothing() + o.elapsed_start=0 + o.elapsed_end=0 + o.elapsed_time=0 + o.udp = options.udp or false + setmetatable(o, { __index = self }) + processor.channel = o + return o +end + +local function _cleanup(self) + logger.dbg(">[_cleanup] clean up executed") + + -- make sure buffers are clean + if not self.udp then + ngx.flush(true) + end + + local p2sSock = self.p2sSock + local c2pSock = self.c2pSock + if p2sSock ~= nil then + if p2sSock.shutdown then + p2sSock:shutdown("send") + end + if p2sSock.close ~= nil then + if not self.udp then + local ok, err = p2sSock:setkeepalive() + if not ok then + logger.err(format(">[_cleanup] Failed to p2sSock:setkeepalive()")) + end + end + end + end + + if c2pSock ~= nil then + if c2pSock.shutdown then + c2pSock:shutdown("send") + end + if c2pSock.close ~= nil then + local ok, err = c2pSock:close() + if not ok then + logger.err(format(">[_cleanup] Failed to c2pSock:close()")) + end + end + end +end + +local function _upl(self) + -- proxy client request to server + local upstream = self.upstream + local buf, err, partial + + local session, err = ses:new(self.processor._PROTOCAL, self.sessionMan) + if err then + logger.err(format(">[_upl] start session fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) + return + end + self.processor.ctx.clientIP = ngx.var.remote_addr + self.processor.ctx.clientPort = ngx.var.remote_port + self.processor.ctx.srvIP = upstream.ip + self.processor.ctx.srvPort = upstream.port + self.processor.ctx.srvID = upstream.id + self.processor.ctx.srvGID = upstream.gid + self.processor.ctx.connTime = ngx.time() + session.ctx = self.processor.ctx + self.session = session + self.OnConnectEvent:trigger({ + clientIP = session.ctx.clientIP, + clientPort = session.ctx.clientPort, + srvIP = session.ctx.srvIP, + srvPort = session.ctx.srvPort + }) + + logger.inf(">[_upl] session processor type: ", self.session.stype) + while true do + --todo: sessionMan should notify session change + if not self.session:valid(self.session) then + self.processor:sessionInvalid(self.session) + else + self.session.uptime = ngx.time() + end + logger.inf(">[_upl] client --> proxy start process") + buf, err, partial = self.processor:processUpRequest(self.standalone) + if err then + logger.err(format(">[_upl] processUpRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) + break + end + --if in standalone mode, don't forward + if not self.standalone and buf then + logger.inf(">_upl()<-sendUp() - self.channel.p2sSend") + local _, err = self.processor:sendUp(buf) + if err then + logger.err(format(">[_upl] forward to upstream fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) + break + end + end + end + self:shutdown(upstream) +end + +local function _dwn(self) + -- logger.inf(">[_dwn] session processor type: ", self.session.stype) + local upstream = self.upstream + -- proxy response to client + local buf, err, partial + while true do + logger.inf(">[_dwn] server --> proxy start process") + buf, err, partial = self.processor:processDownRequest(self.standalone) + if err then + logger.err(format(">[_dwn] processDownRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) + break + end + if buf then + logger.inf(">_dwn()<-sendDown() - self.channel.c2pSend") + local _, err = self.processor:sendDown(buf) + if err then + logger.err(format(">[_dwn] forward to downstream fail: %s:%s, err:%s", upstream.ip, upstream.port, err)) + break + end + end + end + self:shutdown(upstream) +end + +function _M:c2pRead(length) + -- print(debug.traceback()) + logger.inf(">[c2pPRead] c2pSock:receive (" .. (length == "*l" and length .. ", it means reading a line)" or length .. " bytes)")) + local bytes, err, partial = self.c2pSock:receive(length) + -- logger.dbgWithTitle("c2pRead",(bytes and bytes:hex32F() or "")) + return bytes, err, partial +end + +function _M:p2sRead(length) + -- print(debug.traceback()) + logger.inf(">[p2sRead] p2sSock:receive (" .. length .. " bytes)") + local bytes, err, partial = self.p2sSock:receive(length) + -- logger.dbgWithTitle("p2sRead",(bytes and bytes:hex32F() or "")) + return bytes, err, partial +end + +function _M:c2pSend(bytes) + -- print(debug.traceback()) + logger.inf(">[c2pSend] c2pSock:send") + -- logger.dbgWithTitle("c2pSend:send",(bytes and bytes:hex32F() or "")) + return self.c2pSock:send(bytes) +end + +function _M:p2sSend(bytes) + -- print(debug.traceback()) + logger.inf(">[p2sSend] p2sSock:send") + -- logger.dbgWithTitle("p2sSend:send",(bytes and bytes:hex32F() or "")) + return self.p2sSock:send(bytes) +end + +function _M:run() + logger.inf(format(">self.standalone: %s", tostring(self.standalone))) + --this while is to ensure _cleanup will always be executed + while true do + local upstream + if (not self.standalone) then + while true do + upstream = self.balancer:getBest() + if not upstream then + logger.err(format(">[run] failed to get avaliable upstream")) + break + end + + local max_attempts = 3 + local ok, err + -- Retry up to 3 times + for attempts = 0, max_attempts do + if not self.udp then + ok, err = self.p2sSock:connect(upstream.ip, upstream.port) + else + ok, err = self.p2sSock:setpeername(upstream.ip, upstream.port) + end + if not ok then + logger.err(format(">[run] failed to connect to proxy upstream: %s:%s, err:%s", upstream.ip, upstream.port, err)) + self.balancer:blame(upstream) + if attempts < max_attempts then + ngx.sleep(1) + else + logger.err(format(">[run] Attempts exceeded, connection retry terminated")) + break + end + else + if upstream.ssl then + ok, err = self.p2sSock:sslhandshake(nil, upstream.ip, false) + if not ok then + logger.err(format(">[run] failed to ssl handshake: %s:%s, err:%s", upstream.ip, upstream.port, err)) + end + end + break + end + end + if not ok then + logger.err(format(">[run] failed to connect to proxy upstream after %d attempts: %s:%s", max_attempts, upstream.ip, upstream.port)) + break + else + logger.inf(format(">[run] connected to proxy upstream: %s:%s", upstream.ip, upstream.port)) + self.upstream = upstream + break + end + end + end + if not self.standalone and not self.upstream then + logger.err(format(">[run] standalone: %s, upstream: %s:%s", self.standalone, upstream.ip, upstream.port)) + break + end + --_singThreadRun(self) + logger.inf(">[run]::: SPAWN _upl() :::") + local co_upl = spawn(_upl, self) + if (not self.standalone) then + logger.inf(">[run]::: SPAWN _dwn() :::") + local co_dwn = spawn(_dwn, self) + logger.inf(">[run]::: WAIT _dwn() :::") + wait(co_dwn) + end + logger.inf(">[run]::: WAIT _upl() :::") + wait(co_upl) + break + end + _cleanup(self) +end + +function _M:shutdown() + if self.session then + --self.processor:sessionInvalid(self.session) + local err = self.session:kill(self.session) + if err then + logger.err(format(">[shutdown] kill session fail: %s:%s, err:%s", self.upstream.ip, self.upstream.port, err)) + end + end + _cleanup(self) +end + +return _M diff --git a/ws.lua b/ws.lua index 97198c7..a2117f0 100644 --- a/ws.lua +++ b/ws.lua @@ -20,9 +20,10 @@ _M._PROTOCOL ='ws' local protocolPacket=require ("strmproxy.".. _M._PROTOCOL .. ".packets") -function _M.new(self) +function _M.new(self, handshake_request) local o = setmetatable({},{__index=self}) o.c2p_stage = "INIT" + o.c2p_handshake_request = handshake_request o.p2s_stage = "INIT" o.HandshakeRequestEvent=event:newReturnEvent(o,"HandshakeRequestEvent") @@ -33,18 +34,18 @@ function _M.new(self) local parser=require ("strmproxy.".. _M._PROTOCOL ..".parser"):new() o.C2PParser = parser.C2PParser - o.C2PParser.events.TextEvent:addHandler(o, self.OnUpTextEvent) - o.C2PParser.events.BinaryEvent:addHandler(o, self.OnUpBinaryEvent) - o.C2PParser.events.CloseEvent:addHandler(o, self.OnUpCloseEvent) - o.C2PParser.events.PingEvent:addHandler(o, self.OnUpPingEvent) - o.C2PParser.events.PongEvent:addHandler(o, self.OnUpPongEvent) + o.C2PParser.events.TextEvent:addHandler(o, self.OnUploadEvent) + o.C2PParser.events.BinaryEvent:addHandler(o, self.OnUploadEvent) + o.C2PParser.events.CloseEvent:addHandler(o, self.OnUploadEvent) + o.C2PParser.events.PingEvent:addHandler(o, self.OnUploadEvent) + o.C2PParser.events.PongEvent:addHandler(o, self.OnUploadEvent) o.S2PParser = parser.S2PParser - o.S2PParser.events.TextEvent:addHandler(o, self.OnDownTextEvent) - o.S2PParser.events.BinaryEvent:addHandler(o, self.OnDownBinaryEvent) - o.S2PParser.events.CloseEvent:addHandler(o, self.OnDownCloseEvent) - o.S2PParser.events.PingEvent:addHandler(o, self.OnDownPingEvent) - o.S2PParser.events.PongEvent:addHandler(o, self.OnDownPongEvent) + o.S2PParser.events.TextEvent:addHandler(o, self.OnDownloadEvent) + o.S2PParser.events.BinaryEvent:addHandler(o, self.OnDownloadEvent) + o.S2PParser.events.CloseEvent:addHandler(o, self.OnDownloadEvent) + o.S2PParser.events.PingEvent:addHandler(o, self.OnDownloadEvent) + o.S2PParser.events.PongEvent:addHandler(o, self.OnDownloadEvent) return o end @@ -148,44 +149,12 @@ local function readHandshakeResponse(self, sock) end ---------------parser event handlers---------------------- -function _M:OnUpTextEvent(source, packet, up) - self.FrameEvent:trigger({type="text", payload=packet.payload, up="up"}, self.ctx) +function _M:OnUploadEvent(source, packet) + self.FrameEvent:trigger({packet=packet, up=true}, self.ctx) end -function _M:OnUpBinaryEvent(source, packet, up) - self.FrameEvent:trigger({type="binary", payload=packet.payload, up="up"}, self.ctx) -end - -function _M:OnUpCloseEvent(source, packet, up) - self.FrameEvent:trigger({type="close", payload=packet.payload, up="up"}, self.ctx) -end - -function _M:OnUpPingEvent(source, packet, up) - self.FrameEvent:trigger({type="ping", payload=packet.payload, up="up"}, self.ctx) -end - -function _M:OnUpPongEvent(source, packet, up) - self.FrameEvent:trigger({type="pong", payload=packet.payload, up="up"}, self.ctx) -end - -function _M:OnDownTextEvent(source, packet, up) - self.FrameEvent:trigger({type="text", payload=packet.payload, up="down"}, self.ctx) -end - -function _M:OnDownBinaryEvent(source, packet, up) - self.FrameEvent:trigger({type="binary", payload=packet.payload, up="down"}, self.ctx) -end - -function _M:OnDownCloseEvent(source, packet, up) - self.FrameEvent:trigger({type="close", payload=packet.payload, up="down"}, self.ctx) -end - -function _M:OnDownPingEvent(source, packet, up) - self.FrameEvent:trigger({type="ping", payload=packet.payload, up="down"}, self.ctx) -end - -function _M:OnDownPongEvent(source, packet, up) - self.FrameEvent:trigger({type="pong", payload=packet.payload, up="down"}, self.ctx) +function _M:OnDownloadEvent(source, packet) + self.FrameEvent:trigger({packet=packet, up=false}, self.ctx) end ---------------receive and parse packet---------------------- @@ -324,14 +293,21 @@ end function _M.processUpRequest(self) - if (self.c2p_stage == "INIT") then + if self.c2p_stage == "INIT" then - logger.dbg("Websocket>: readHandshakeRequest()") - local allBytes, handshake, err = readHandshakeRequest(self, self.channel.c2pSock) - if err then - logger.err("Websocket>: -- Failed to read handshake request: ", err) - return nil,err + local allBytes, handshake, err + + if self.c2p_handshake_request then + allBytes = self.c2p_handshake_request + handshake = parse_handshake(allBytes) + else + allBytes, handshake, err = readHandshakeRequest(self, self.channel.c2pSock) + if err then + logger.err("Websocket>: -- Failed to read handshake request: ", err) + return nil,err + end end + logger.dbg("Websocket>: readHandshakeRequest()") self.c2p_stage = "HANDSHAKE" self.HandshakeRequestEvent:trigger(handshake, self.ctx) return allBytes @@ -360,7 +336,7 @@ end function _M.processDownRequest(self) - if (self.p2s_stage == "INIT") then + if self.p2s_stage == "INIT" then logger.dbg("Websocket>: readHandshakeResponse()") local allBytes, handshake, err = readHandshakeResponse(self, self.channel.p2sSock) diff --git a/ws/packets.lua b/ws/packets.lua index 7530074..851deaa 100644 --- a/ws/packets.lua +++ b/ws/packets.lua @@ -5,7 +5,6 @@ local extends=tableUtils.extends local orderTable=tableUtils.OrderedTable local bit = require "bit" local ffi = require "ffi" -local zlib = require("zlib") local byte = string.byte local char = string.char @@ -58,7 +57,8 @@ _M.Base = { self.opcode = band(fst, 0x0f) self.fin = band(fst, 0x80) ~= 0 self.mask = band(snd, 0x80) ~= 0 - + self.rsv1 = band(fst, 0x40) ~= 0 + if self.mask then local payload_len = #data - 4 -- TODO string.buffer optimizations diff --git a/wshandler.lua b/wshandler.lua index 32059d0..ddb06d1 100644 --- a/wshandler.lua +++ b/wshandler.lua @@ -54,16 +54,17 @@ local function OnHandshakeResponseEvent(context, source, headers) wsLog(log) end -local function OnFrameEvent(context, source, packet) +local function OnFrameEvent(context, source, frame) local log - if (packet.up == "up") then + local packet = frame.packet + if frame.up then log = format('['.._M._PROTOCOL .. "] %s:%s sent \t\t[%s] %s\r\n", source.ctx.clientIP, source.ctx.clientPort, - packet.type, packet.payload) + packet.opcode, packet.payload) else log = format('['.._M._PROTOCOL .. "] received from %s:%s \t[%s] %s\r\n", source.ctx.srvIP, source.ctx.srvPort, - packet.type, packet.payload) + packet.opcode, packet.payload) end logger.dbg("[" .. _M._PROTOCOL .. " ] ", log)