diff --git a/cmpctircd/IRCd.cs b/cmpctircd/IRCd.cs index 41f2bd6..bc132f3 100644 --- a/cmpctircd/IRCd.cs +++ b/cmpctircd/IRCd.cs @@ -9,6 +9,8 @@ using cmpctircd.Modes; public class IRCd { + private readonly ISocketListenerFactory socketListenerFactory; + private readonly ISocketConnectorFactory socketConnectorFactory; private readonly IList Listeners = new List(); public readonly IList Connectors = new List(); public PacketManager PacketManager { get; } @@ -50,9 +52,11 @@ public class IRCd { public List Clients => ClientLists.SelectMany(clientList => clientList).ToList(); public List Servers => ServerLists.SelectMany(serverList => serverList).ToList(); - public IRCd(Log log, CmpctConfigurationSection config, IServiceProvider services) { + public IRCd(Log log, CmpctConfigurationSection config, IServiceProvider services, ISocketListenerFactory socketListenerFactory, ISocketConnectorFactory socketConnectorFactory) { this.Log = log; this.Config = config; + this.socketListenerFactory = socketListenerFactory; + this.socketConnectorFactory = socketConnectorFactory; // Interpret the ConfigData SID = config.SID; @@ -99,7 +103,7 @@ public void Run() { PacketManager.Load(); foreach(var listener in Config.Sockets.OfType()) { - SocketListener sl = new SocketListener(this, listener); + SocketListener sl = socketListenerFactory.CreateSocketListener(this, listener); Log.Info($"==> Listening on: {listener.Host}:{listener.Port} ({listener.Type}) ({(listener.IsTls ? "TLS" : "Plain" )})"); Listeners.Add(sl); @@ -110,7 +114,7 @@ public void Run() { if (server.IsOutbound) { // tag with outbound="true" // We want to connect out to this server, not have them connect to us - var sc = new SocketConnector(this, server); + var sc = socketConnectorFactory.CreateSocketConnector(this, server); Log.Info($"==> Connecting to: {server.Destination}:{server.Port} ({server.Host}) ({(server.IsTls ? "TLS" : "Plain" )})"); Connectors.Add(sc); diff --git a/cmpctircd/ISocketConnectorFactory.cs b/cmpctircd/ISocketConnectorFactory.cs new file mode 100644 index 0000000..d62ccf7 --- /dev/null +++ b/cmpctircd/ISocketConnectorFactory.cs @@ -0,0 +1,7 @@ +using cmpctircd.Configuration; + +namespace cmpctircd { + public interface ISocketConnectorFactory { + SocketConnector CreateSocketConnector(IRCd ircd, ServerElement config); + } +} \ No newline at end of file diff --git a/cmpctircd/ISocketListenerFactory.cs b/cmpctircd/ISocketListenerFactory.cs new file mode 100644 index 0000000..824150b --- /dev/null +++ b/cmpctircd/ISocketListenerFactory.cs @@ -0,0 +1,7 @@ +using cmpctircd.Configuration; + +namespace cmpctircd { + public interface ISocketListenerFactory { + SocketListener CreateSocketListener(IRCd ircd, SocketElement config); + } +} \ No newline at end of file diff --git a/cmpctircd/Program.cs b/cmpctircd/Program.cs index 5d9c0af..7176627 100644 --- a/cmpctircd/Program.cs +++ b/cmpctircd/Program.cs @@ -28,6 +28,8 @@ static IHostBuilder CreateHostBuilder(string[] args) { services.AddScoped(); services.AddScoped(sp => sp.GetRequiredService().Sender as Client); services.AddScoped(sp => sp.GetRequiredService().Sender as Server); + services.AddTransient(); + services.AddTransient(); services.AddHostedService(); }); } diff --git a/cmpctircd/SocketConnector.cs b/cmpctircd/SocketConnector.cs index 7c252e9..ed20e8a 100644 --- a/cmpctircd/SocketConnector.cs +++ b/cmpctircd/SocketConnector.cs @@ -9,14 +9,15 @@ namespace cmpctircd { public class SocketConnector : SocketListener { - + private readonly Log log; public ServerElement ServerInfo; public bool Connected; private TcpClient tc; private NetworkStream stream; - public SocketConnector(IRCd ircd, ServerElement info) : base(ircd, info) { + public SocketConnector(Log log, IRCd ircd, ServerElement info) : base(log, ircd, info) { + this.log = log ?? throw new ArgumentNullException(nameof(log)); ServerInfo = info; } @@ -38,7 +39,7 @@ public async Task Connect() { await tc.ConnectAsync(Info.Host.ToString(), Info.Port); stream = tc.GetStream(); } catch (SocketException) { - _ircd.Log.Warn($"Unable to connect to server {Info.Host.ToString()}:{Info.Port}"); + log.Warn($"Unable to connect to server {Info.Host.ToString()}:{Info.Port}"); return; } diff --git a/cmpctircd/SocketConnectorFactory.cs b/cmpctircd/SocketConnectorFactory.cs new file mode 100644 index 0000000..df3c25e --- /dev/null +++ b/cmpctircd/SocketConnectorFactory.cs @@ -0,0 +1,16 @@ +using cmpctircd.Configuration; +using System; + +namespace cmpctircd { + public class SocketConnectorFactory : ISocketConnectorFactory { + private readonly Log log; + + public SocketConnectorFactory(Log log) { + this.log = log ?? throw new ArgumentNullException(nameof(log)); + } + + public SocketConnector CreateSocketConnector(IRCd ircd, ServerElement config) { + return new SocketConnector(log, ircd, config); + } + } +} diff --git a/cmpctircd/SocketListener.cs b/cmpctircd/SocketListener.cs index 19bdfa6..7ec3d5f 100644 --- a/cmpctircd/SocketListener.cs +++ b/cmpctircd/SocketListener.cs @@ -12,6 +12,7 @@ namespace cmpctircd { public class SocketListener { + private readonly Log log; protected IRCd _ircd; private Boolean _started = false; private TcpListener _listener = null; @@ -25,7 +26,8 @@ public class SocketListener { public int ServerCount = 0; public int AuthServerCount = 0; - public SocketListener(IRCd ircd, SocketElement info) { + public SocketListener(Log log, IRCd ircd, SocketElement info) { + this.log = log ?? throw new ArgumentNullException(nameof(log)); this._ircd = ircd; this.Info = info; _listener = new TcpListener(info.Host, info.Port); @@ -43,7 +45,7 @@ public virtual void Bind() { } public virtual void Stop() { if (_started) { - _ircd.Log.Debug($"Shutting down listener [IP: {Info.Host}, Port: {Info.Port}, TLS: {Info.IsTls}]"); + log.Debug($"Shutting down listener [IP: {Info.Host}, Port: {Info.Port}, TLS: {Info.IsTls}]"); _listener.Stop(); _started = false; } @@ -61,7 +63,7 @@ public async Task ListenToClients() { TcpClient tc = await _listener.AcceptTcpClientAsync(); HandleClientAsync(tc); // this should split off execution } catch(Exception e) { - _ircd.Log.Error($"Exception in ListenToClients(): {e.ToString()}"); + log.Error($"Exception in ListenToClients(): {e.ToString()}"); } } } @@ -72,7 +74,7 @@ protected async Task HandshakeIfNeededAsync(TcpClient tc, Stream stream) try { stream = await HandshakeTlsAsServerAsync(tc); } catch (Exception e) { - _ircd.Log.Debug($"Exception in {nameof(HandshakeTlsAsServerAsync)}: {e}"); + log.Debug($"Exception in {nameof(HandshakeTlsAsServerAsync)}: {e}"); tc.Close(); } } @@ -189,7 +191,7 @@ public async Task HandshakeTlsAsClient(TcpClient tc, string host, boo if (verifyCert) { stream = new SslStream(tc.GetStream(), true); } else { - _ircd.Log.Warn($"[SERVER] Connecting out to server {host} with TLS verification disabled: this is dangerous!"); + log.Warn($"[SERVER] Connecting out to server {host} with TLS verification disabled: this is dangerous!"); stream = new SslStream(tc.GetStream(), true, (sender, certificate, chain, sslPolicyErrors) => true); } diff --git a/cmpctircd/SocketListenerFactory.cs b/cmpctircd/SocketListenerFactory.cs new file mode 100644 index 0000000..bf9575b --- /dev/null +++ b/cmpctircd/SocketListenerFactory.cs @@ -0,0 +1,16 @@ +using cmpctircd.Configuration; +using System; + +namespace cmpctircd { + public class SocketListenerFactory : ISocketListenerFactory { + private readonly Log log; + + public SocketListenerFactory(Log log) { + this.log = log ?? throw new ArgumentNullException(nameof(log)); + } + + public SocketListener CreateSocketListener(IRCd ircd, SocketElement config) { + return new SocketListener(log, ircd, config); + } + } +}