diff --git a/src/csharp/Microsoft.Spark.UnitTest/SparkFixture.cs b/src/csharp/Microsoft.Spark.UnitTest/SparkFixture.cs index 06c9a3fe2..c7aed2d4d 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/SparkFixture.cs +++ b/src/csharp/Microsoft.Spark.UnitTest/SparkFixture.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; using Microsoft.Spark.Interop; using Microsoft.Spark.Interop.Ipc; using Moq; @@ -27,7 +28,7 @@ public SparkFixture() var mockJvmBridgeFactory = new Mock(); mockJvmBridgeFactory - .Setup(m => m.Create(It.IsAny())) + .Setup(m => m.Create(It.IsAny(), It.IsAny())) .Returns(MockJvm.Object); SparkEnvironment.JvmBridgeFactory = mockJvmBridgeFactory.Object; diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadProcessorTests.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadProcessorTests.cs index b2edba995..fb93df55d 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadProcessorTests.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadProcessorTests.cs @@ -71,11 +71,11 @@ public void TestClosedStreamWithSocket() PayloadWriter payloadWriter = new PayloadWriterFactory().Create(); Payload payload = TestData.GetDefaultPayload(); - using var serverListener = new DefaultSocketWrapper(); + using var serverListener = new DefaultSocketWrapper(IPAddress.Loopback); serverListener.Listen(); var port = (serverListener.LocalEndPoint as IPEndPoint).Port; - using var clientSocket = new DefaultSocketWrapper(); + using var clientSocket = new DefaultSocketWrapper(IPAddress.Loopback); clientSocket.Connect(IPAddress.Loopback, port, null); using (ISocketWrapper serverSocket = serverListener.Accept()) diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/TaskRunnerTests.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/TaskRunnerTests.cs index f495e142b..23a1a5a46 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/TaskRunnerTests.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/TaskRunnerTests.cs @@ -16,11 +16,11 @@ public class TaskRunnerTests [Fact] public void TestTaskRunner() { - using var serverListener = new DefaultSocketWrapper(); + using var serverListener = new DefaultSocketWrapper(IPAddress.Loopback); serverListener.Listen(); var port = (serverListener.LocalEndPoint as IPEndPoint).Port; - var clientSocket = new DefaultSocketWrapper(); + var clientSocket = new DefaultSocketWrapper(IPAddress.Loopback); clientSocket.Connect(IPAddress.Loopback, port, null); PayloadWriter payloadWriter = new PayloadWriterFactory().Create(); diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs index d86fd7305..09b5a8b0b 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs @@ -155,7 +155,8 @@ internal void Run(ISocketWrapper listener) /// private void Run() { - Run(SocketFactory.CreateSocket()); + IPAddress dotnetCallbackServerIpAddress = SparkEnvironment.ConfigurationService.GetCallbackServerIPAddress(); + Run(SocketFactory.CreateSocket(dotnetCallbackServerIpAddress)); } /// diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/IJvmBridgeFactory.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/IJvmBridgeFactory.cs index 428565527..b8be2918c 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/IJvmBridgeFactory.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/IJvmBridgeFactory.cs @@ -2,10 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Net; + namespace Microsoft.Spark.Interop.Ipc { internal interface IJvmBridgeFactory { IJvmBridge Create(int portNumber); + + IJvmBridge Create(IPAddress ip, int portNumber); } } diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index 6c8d61840..2d187b12d 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -41,19 +41,24 @@ internal sealed class JvmBridge : IJvmBridge new ConcurrentQueue(); private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); + private readonly IPAddress _ipAddress; private readonly int _portNumber; private readonly JvmThreadPoolGC _jvmThreadPoolGC; private readonly bool _isRunningRepl; - internal JvmBridge(int portNumber) + internal JvmBridge(int portNumber): this(IPAddress.Loopback, portNumber) + { + } + + internal JvmBridge(IPAddress ipAddress, int portNumber) { if (portNumber == 0) { throw new Exception("Port number is not set."); } - + _ipAddress = ipAddress; _portNumber = portNumber; - _logger.LogInfo($"JvMBridge port is {portNumber}"); + _logger.LogInfo($"JvMBridge IP is {_ipAddress} port is {_portNumber}"); _jvmThreadPoolGC = new JvmThreadPoolGC( _logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval, _processId); @@ -83,8 +88,9 @@ private ISocketWrapper GetConnection() _socketSemaphore.Wait(); if (!_sockets.TryDequeue(out ISocketWrapper socket)) { - socket = SocketFactory.CreateSocket(); - socket.Connect(IPAddress.Loopback, _portNumber); + IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint(); + socket = SocketFactory.CreateSocket(dotnetBackendIPEndpoint.Address); + socket.Connect(_ipAddress, _portNumber); } return socket; diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridgeFactory.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridgeFactory.cs index 9c9f4ca43..f8bee0e65 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridgeFactory.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridgeFactory.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Net; + namespace Microsoft.Spark.Interop.Ipc { internal class JvmBridgeFactory : IJvmBridgeFactory @@ -10,5 +12,10 @@ public IJvmBridge Create(int portNumber) { return new JvmBridge(portNumber); } + + public IJvmBridge Create(IPAddress ipAddress, int portNumber) + { + return new JvmBridge(ipAddress, portNumber); + } } } diff --git a/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs b/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs index cf2c2dc0e..893318954 100644 --- a/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs +++ b/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Services; @@ -70,8 +71,9 @@ public static IJvmBridge JvmBridge { get { + IPEndPoint jvmBackendEndPoint = ConfigurationService.GetBackendIPEndpoint(); return s_jvmBridge ??= - JvmBridgeFactory.Create(ConfigurationService.GetBackendPortNumber()); + JvmBridgeFactory.Create(jvmBackendEndPoint.Address, jvmBackendEndPoint.Port); } set { diff --git a/src/csharp/Microsoft.Spark/Network/DefaultSocketWrapper.cs b/src/csharp/Microsoft.Spark/Network/DefaultSocketWrapper.cs index 4f4a4491d..efcf9aab0 100644 --- a/src/csharp/Microsoft.Spark/Network/DefaultSocketWrapper.cs +++ b/src/csharp/Microsoft.Spark/Network/DefaultSocketWrapper.cs @@ -6,6 +6,7 @@ using System.IO; using System.Net; using System.Net.Sockets; +using Microsoft.Spark.Interop; using Microsoft.Spark.Services; using Microsoft.Spark.Utils; @@ -24,12 +25,12 @@ internal sealed class DefaultSocketWrapper : ISocketWrapper /// Default constructor that creates a new instance of DefaultSocket class which represents /// a traditional socket (System.Net.Socket.Socket). /// - /// This socket is bound to Loopback with port 0. + /// This socket is bound to provided IP address with port 0. /// - public DefaultSocketWrapper() : - this(new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + public DefaultSocketWrapper(IPAddress ipAddress) : + this(new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp)) { - _innerSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + _innerSocket.Bind(new IPEndPoint(ipAddress, 0)); } /// diff --git a/src/csharp/Microsoft.Spark/Network/SocketFactory.cs b/src/csharp/Microsoft.Spark/Network/SocketFactory.cs index 46c0f6c0f..e0a4a62ea 100644 --- a/src/csharp/Microsoft.Spark/Network/SocketFactory.cs +++ b/src/csharp/Microsoft.Spark/Network/SocketFactory.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Net; + namespace Microsoft.Spark.Network { /// @@ -17,7 +19,12 @@ internal static class SocketFactory /// public static ISocketWrapper CreateSocket() { - return new DefaultSocketWrapper(); + return new DefaultSocketWrapper(IPAddress.Loopback); + } + + public static ISocketWrapper CreateSocket(IPAddress ip) + { + return new DefaultSocketWrapper(ip); } } } diff --git a/src/csharp/Microsoft.Spark/RDD.cs b/src/csharp/Microsoft.Spark/RDD.cs index 8c23cc87a..c4821d158 100644 --- a/src/csharp/Microsoft.Spark/RDD.cs +++ b/src/csharp/Microsoft.Spark/RDD.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Net; +using Microsoft.Spark.Interop; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Network; using Microsoft.Spark.Utils; @@ -262,7 +263,8 @@ public IEnumerable Collect() { (int port, string secret) = CollectAndServe(); using ISocketWrapper socket = SocketFactory.CreateSocket(); - socket.Connect(IPAddress.Loopback, port, secret); + IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint(); + socket.Connect(dotnetBackendIPEndpoint.Address, port, secret); var collector = new RDD.Collector(); System.IO.Stream stream = socket.InputStream; diff --git a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs index 505868fa6..e431a1fcd 100644 --- a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Net; using System.Runtime.InteropServices; using static System.Environment; using Microsoft.Spark.Utils; @@ -24,6 +25,8 @@ internal sealed class ConfigurationService : IConfigurationService internal const string WorkerVerDirEnvVarNameFormat = "DOTNET_WORKER_{0}_DIR"; private const string DotnetBackendPortEnvVarName = "DOTNETBACKEND_PORT"; + private const string DotnetBackendIPAddressEnvVarName = "DOTNET_SPARK_BACKEND_IP_ADDRESS"; + private const string DotnetCallbackServerIPAddressEnvVarName = "DOTNET_SPARK_CALLBACK_SERVER_IP_ADDRESS"; private const int DotnetBackendDebugPort = 5567; private const string DotnetNumBackendThreadsEnvVarName = "DOTNET_SPARK_NUM_BACKEND_THREADS"; @@ -99,21 +102,26 @@ public TimeSpan JvmThreadGCInterval !string.IsNullOrEmpty(GetEnvironmentVariable("DATABRICKS_RUNTIME_VERSION")); /// - /// Returns the port number for socket communication between JVM and CLR. + /// Returns the IP Endpoint for socket communication between JVM and CLR. /// - public int GetBackendPortNumber() + public IPEndPoint GetBackendIPEndpoint() { if (!int.TryParse( - GetEnvironmentVariable(DotnetBackendPortEnvVarName), + Environment.GetEnvironmentVariable(DotnetBackendPortEnvVarName), out int portNumber)) { _logger.LogInfo($"'{DotnetBackendPortEnvVarName}' environment variable is not set."); portNumber = DotnetBackendDebugPort; } - - _logger.LogInfo($"Using port {portNumber} for connection."); - - return portNumber; + string ipAddress = Environment.GetEnvironmentVariable(DotnetBackendIPAddressEnvVarName); + if (ipAddress == null) + { + _logger.LogInfo($"'{DotnetBackendIPAddressEnvVarName}' environment variable is not set."); + ipAddress = "127.0.0.1"; + } + _logger.LogInfo($"Using IP address {ipAddress} and port {portNumber} for connection."); + + return new IPEndPoint(IPAddress.Parse(ipAddress), portNumber); } /// @@ -131,6 +139,22 @@ public int GetNumBackendThreads() return numThreads; } + /// + /// Returns the IP address for socket communication between JVM and CallBack Server. + /// + public IPAddress GetCallbackServerIPAddress() + { + string ipAddress = Environment.GetEnvironmentVariable(DotnetCallbackServerIPAddressEnvVarName); + if (ipAddress == null) + { + _logger.LogInfo($"'{DotnetCallbackServerIPAddressEnvVarName}' environment variable is not set."); + ipAddress = "127.0.0.1"; + } + _logger.LogInfo($"Using IP address {ipAddress} for connection with Callback Server."); + + return IPAddress.Parse(ipAddress); + } + /// /// Returns the worker executable path. /// diff --git a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs index 0cf211272..447e4aac4 100644 --- a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; namespace Microsoft.Spark.Services { @@ -17,14 +18,19 @@ internal interface IConfigurationService TimeSpan JvmThreadGCInterval { get; } /// - /// The port number used for communicating with the .NET backend process. + /// Returns the max number of threads for socket communication between JVM and CLR. /// - int GetBackendPortNumber(); + int GetNumBackendThreads(); /// - /// Returns the max number of threads for socket communication between JVM and CLR. + /// The IP Endpoint used for communicating with the .NET backend process. /// - int GetNumBackendThreads(); + IPEndPoint GetBackendIPEndpoint(); + + /// + /// The IP address used for communicating with CallBack server. + /// + IPAddress GetCallbackServerIPAddress(); /// /// The full path to the .NET worker executable. diff --git a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs index e8e86f3e8..eba977aae 100644 --- a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs +++ b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs @@ -849,7 +849,8 @@ public IEnumerable ToLocalIterator(bool prefetchPartitions) Reference.Invoke("toPythonIterator", prefetchPartitions), true); using ISocketWrapper socket = SocketFactory.CreateSocket(); - socket.Connect(IPAddress.Loopback, port, secret); + IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint(); + socket.Connect(dotnetBackendIPEndpoint.Address, port, secret); foreach (Row row in new RowCollector().Collect(socket, server)) { yield return row; @@ -1077,15 +1078,18 @@ public int SemanticHash() => /// /// String name of function to call /// Arguments to the function - /// IEnumerable of Rows from Spark + /// private IEnumerable GetRows(string funcName, params object[] args) { (int port, string secret, _) = GetConnectionInfo(funcName, args); - using ISocketWrapper socket = SocketFactory.CreateSocket(); - socket.Connect(IPAddress.Loopback, port, secret); - foreach (Row row in new RowCollector().Collect(socket)) + IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint(); + using (ISocketWrapper socket = SocketFactory.CreateSocket()) { - yield return row; + socket.Connect(dotnetBackendIPEndpoint.Address, port, secret); + foreach (Row row in new RowCollector().Collect(socket)) + { + yield return row; + } } } diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index 1d8215d44..66766f43c 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -35,7 +35,7 @@ class DotnetBackend extends Logging { @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None - def init(portNumber: Int): Int = { + def init(ipAddress: String, portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") @@ -63,7 +63,7 @@ class DotnetBackend extends Logging { } }) - channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture = bootstrap.bind(new InetSocketAddress(ipAddress, portNumber)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala index 8cf73ddd0..6d3f13288 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -63,6 +63,7 @@ object DotnetRunner extends Logging { // In debug mode this runner will not launch a .NET process. val runInDebugMode = settings._1 @volatile var dotnetBackendPortNumber = settings._2 + val dotnetBackendIPAddress = sys.env.getOrElse("DOTNET_SPARK_BACKEND_IP_ADDRESS", "127.0.0.1") var dotnetExecutable = "" var otherArgs: Array[String] = null @@ -110,8 +111,9 @@ object DotnetRunner extends Logging { override def run() { // need to get back dotnetBackendPortNumber because if the value passed to init is 0 // the port number is dynamically assigned in the backend - dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) - logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendIPAddress, dotnetBackendPortNumber) + logInfo(s"IP address used by DotnetBackend is $dotnetBackendIPAddress and " + + s"Port number used is $dotnetBackendPortNumber") initialized.release() dotnetBackend.run() } diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index c6f528aee..b59ef8b76 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -34,7 +34,7 @@ class DotnetBackend extends Logging { @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None - def init(portNumber: Int): Int = { + def init(ipAddress: String, portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") @@ -62,7 +62,7 @@ class DotnetBackend extends Logging { } }) - channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture = bootstrap.bind(new InetSocketAddress(ipAddress, portNumber)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala index 9762a9aa7..8e3b37d4a 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -7,7 +7,8 @@ package org.apache.spark.deploy.dotnet import java.io.File -import java.net.URI +import java.lang.NumberFormatException +import java.net.{InetAddress, URI, UnknownHostException} import java.nio.file.attribute.PosixFilePermissions import java.nio.file.{FileSystems, Files, Paths} import java.util.Locale @@ -62,6 +63,7 @@ object DotnetRunner extends Logging { // In debug mode this runner will not launch a .NET process. val runInDebugMode = settings._1 @volatile var dotnetBackendPortNumber = settings._2 + val dotnetBackendIPAddress = sys.env.getOrElse("DOTNET_SPARK_BACKEND_IP_ADDRESS", "127.0.0.1") var dotnetExecutable = "" var otherArgs: Array[String] = null @@ -109,8 +111,9 @@ object DotnetRunner extends Logging { override def run() { // need to get back dotnetBackendPortNumber because if the value passed to init is 0 // the port number is dynamically assigned in the backend - dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) - logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendIPAddress, dotnetBackendPortNumber) + logInfo(s"IP address used by DotnetBackend is $dotnetBackendIPAddress and " + + s"Port number used is $dotnetBackendPortNumber") initialized.release() dotnetBackend.run() } @@ -274,9 +277,9 @@ object DotnetRunner extends Logging { if (args.length == 1) { portNumber = DEBUG_PORT } else if (args.length == 2) { - portNumber = Integer.parseInt(args(1)) + portNumber = Integer.parseInt(args(1)) + } } - } (runInDebugMode, portNumber) } diff --git a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index c6f528aee..b59ef8b76 100644 --- a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -34,7 +34,7 @@ class DotnetBackend extends Logging { @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None - def init(portNumber: Int): Int = { + def init(ipAddress: String, portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") @@ -62,7 +62,7 @@ class DotnetBackend extends Logging { } }) - channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture = bootstrap.bind(new InetSocketAddress(ipAddress, portNumber)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } diff --git a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala index b7b2c9884..76ccee30d 100644 --- a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala +++ b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -8,6 +8,8 @@ package org.apache.spark.deploy.dotnet import java.io.File import java.net.URI +import java.lang.NumberFormatException +import java.net.{InetAddress, URI, UnknownHostException} import java.nio.file.attribute.PosixFilePermissions import java.nio.file.{FileSystems, Files, Paths} import java.util.Locale @@ -62,6 +64,7 @@ object DotnetRunner extends Logging { // In debug mode this runner will not launch a .NET process. val runInDebugMode = settings._1 @volatile var dotnetBackendPortNumber = settings._2 + val dotnetBackendIPAddress = sys.env.getOrElse("DOTNET_SPARK_BACKEND_IP_ADDRESS", "127.0.0.1") var dotnetExecutable = "" var otherArgs: Array[String] = null @@ -109,8 +112,9 @@ object DotnetRunner extends Logging { override def run() { // need to get back dotnetBackendPortNumber because if the value passed to init is 0 // the port number is dynamically assigned in the backend - dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) - logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendIPAddress, dotnetBackendPortNumber) + logInfo(s"IP address used by DotnetBackend is $dotnetBackendIPAddress and " + + s"Port number used is $dotnetBackendPortNumber") initialized.release() dotnetBackend.run() } @@ -126,6 +130,7 @@ object DotnetRunner extends Logging { val builder = new ProcessBuilder(processParameters) val env = builder.environment() env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString) + env.put("DOTNETBACKEND_IP_ADDRESS", dotnetBackendIPAddress) for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { env.put(key, value) @@ -266,19 +271,30 @@ object DotnetRunner extends Logging { returnCode } - private def initializeSettings(args: Array[String]): (Boolean, Int) = { + private def initializeSettings(args: Array[String]): (Boolean, Int, String) = { val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase( "debug") var portNumber = 0 + var dotnetBackendIPAddress = "localhost" if (runInDebugMode) { if (args.length == 1) { portNumber = DEBUG_PORT } else if (args.length == 2) { - portNumber = Integer.parseInt(args(1)) + portNumber = Integer.parseInt(args(1)) + } + } + else { + try { + var addr = InetAddress.getByName(args(0)) + dotnetBackendIPAddress = args(0) + } + catch { + case e: UnknownHostException => + dotnetBackendIPAddress = "localhost" } } - (runInDebugMode, portNumber) + (runInDebugMode, portNumber, dotnetBackendIPAddress) } private def logThrowable(throwable: Throwable): Unit = diff --git a/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index c6f528aee..b59ef8b76 100644 --- a/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -34,7 +34,7 @@ class DotnetBackend extends Logging { @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None - def init(portNumber: Int): Int = { + def init(ipAddress: String, portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") @@ -62,7 +62,7 @@ class DotnetBackend extends Logging { } }) - channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture = bootstrap.bind(new InetSocketAddress(ipAddress, portNumber)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } diff --git a/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala index 38ecca63b..3333401ce 100644 --- a/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala +++ b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -8,6 +8,8 @@ package org.apache.spark.deploy.dotnet import java.io.File import java.net.URI +import java.lang.NumberFormatException +import java.net.{InetAddress, URI, UnknownHostException} import java.nio.file.attribute.PosixFilePermissions import java.nio.file.{FileSystems, Files, Paths} import java.util.Locale @@ -62,6 +64,7 @@ object DotnetRunner extends Logging { // In debug mode this runner will not launch a .NET process. val runInDebugMode = settings._1 @volatile var dotnetBackendPortNumber = settings._2 + val dotnetBackendIPAddress = sys.env.getOrElse("DOTNET_SPARK_BACKEND_IP_ADDRESS", "127.0.0.1") var dotnetExecutable = "" var otherArgs: Array[String] = null @@ -109,8 +112,9 @@ object DotnetRunner extends Logging { override def run() { // need to get back dotnetBackendPortNumber because if the value passed to init is 0 // the port number is dynamically assigned in the backend - dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) - logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendIPAddress, dotnetBackendPortNumber) + logInfo(s"IP address used by DotnetBackend is $dotnetBackendIPAddress and " + + s"Port number used is $dotnetBackendPortNumber") initialized.release() dotnetBackend.run() } @@ -126,6 +130,7 @@ object DotnetRunner extends Logging { val builder = new ProcessBuilder(processParameters) val env = builder.environment() env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString) + env.put("DOTNETBACKEND_IP_ADDRESS", dotnetBackendIPAddress) for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { env.put(key, value) @@ -265,19 +270,30 @@ object DotnetRunner extends Logging { returnCode } - private def initializeSettings(args: Array[String]): (Boolean, Int) = { + private def initializeSettings(args: Array[String]): (Boolean, Int, String) = { val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase( "debug") var portNumber = 0 + var dotnetBackendIPAddress = "localhost" if (runInDebugMode) { if (args.length == 1) { portNumber = DEBUG_PORT } else if (args.length == 2) { - portNumber = Integer.parseInt(args(1)) + portNumber = Integer.parseInt(args(1)) + } + } + else { + try { + var addr = InetAddress.getByName(args(0)) + dotnetBackendIPAddress = args(0) + } + catch { + case e: UnknownHostException => + dotnetBackendIPAddress = "localhost" } } - (runInDebugMode, portNumber) + (runInDebugMode, portNumber, dotnetBackendIPAddress) } private def logThrowable(throwable: Throwable): Unit =