3131import com .github .shyiko .mysql .binlog .io .ByteArrayInputStream ;
3232import com .github .shyiko .mysql .binlog .jmx .BinaryLogClientMXBean ;
3333import com .github .shyiko .mysql .binlog .network .AuthenticationException ;
34+ import com .github .shyiko .mysql .binlog .network .ClientCapabilities ;
35+ import com .github .shyiko .mysql .binlog .network .DefaultSSLSocketFactory ;
36+ import com .github .shyiko .mysql .binlog .network .SSLMode ;
37+ import com .github .shyiko .mysql .binlog .network .SSLSocketFactory ;
3438import com .github .shyiko .mysql .binlog .network .ServerException ;
3539import com .github .shyiko .mysql .binlog .network .SocketFactory ;
40+ import com .github .shyiko .mysql .binlog .network .TLSHostnameVerifier ;
3641import com .github .shyiko .mysql .binlog .network .protocol .ErrorPacket ;
3742import com .github .shyiko .mysql .binlog .network .protocol .GreetingPacket ;
3843import com .github .shyiko .mysql .binlog .network .protocol .Packet ;
4449import com .github .shyiko .mysql .binlog .network .protocol .command .DumpBinaryLogGtidCommand ;
4550import com .github .shyiko .mysql .binlog .network .protocol .command .PingCommand ;
4651import com .github .shyiko .mysql .binlog .network .protocol .command .QueryCommand ;
52+ import com .github .shyiko .mysql .binlog .network .protocol .command .SSLRequestCommand ;
4753
54+ import javax .net .ssl .SSLContext ;
55+ import javax .net .ssl .TrustManager ;
56+ import javax .net .ssl .X509TrustManager ;
4857import java .io .EOFException ;
4958import java .io .IOException ;
5059import java .net .InetSocketAddress ;
5160import java .net .Socket ;
5261import java .net .SocketException ;
62+ import java .security .GeneralSecurityException ;
63+ import java .security .cert .CertificateException ;
64+ import java .security .cert .X509Certificate ;
5365import java .util .Arrays ;
5466import java .util .Collections ;
5567import java .util .Iterator ;
7486 */
7587public class BinaryLogClient implements BinaryLogClientMXBean {
7688
89+ private static final SSLSocketFactory DEFAULT_REQUIRED_SSL_MODE_SOCKET_FACTORY = new DefaultSSLSocketFactory () {
90+
91+ @ Override
92+ protected void initSSLContext (SSLContext sc ) throws GeneralSecurityException {
93+ sc .init (null , new TrustManager []{
94+ new X509TrustManager () {
95+
96+ @ Override
97+ public void checkClientTrusted (X509Certificate [] x509Certificates , String s )
98+ throws CertificateException { }
99+
100+ @ Override
101+ public void checkServerTrusted (X509Certificate [] x509Certificates , String s )
102+ throws CertificateException { }
103+
104+ @ Override
105+ public X509Certificate [] getAcceptedIssuers () {
106+ return new X509Certificate [0 ];
107+ }
108+ }
109+ }, null );
110+ }
111+ };
112+ private static final SSLSocketFactory DEFAULT_VERIFY_CA_SSL_MODE_SOCKET_FACTORY = new DefaultSSLSocketFactory ();
113+
77114 // https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
78115 private static final int MAX_PACKET_LENGTH = 16777215 ;
79116
@@ -90,6 +127,7 @@ public class BinaryLogClient implements BinaryLogClientMXBean {
90127 private volatile String binlogFilename ;
91128 private volatile long binlogPosition = 4 ;
92129 private volatile long connectionId ;
130+ private SSLMode sslMode = SSLMode .DISABLED ;
93131
94132 private GtidSet gtidSet ;
95133 private final Object gtidSetAccessLock = new Object ();
@@ -100,6 +138,7 @@ public class BinaryLogClient implements BinaryLogClientMXBean {
100138 private final List <LifecycleListener > lifecycleListeners = new LinkedList <LifecycleListener >();
101139
102140 private SocketFactory socketFactory ;
141+ private SSLSocketFactory sslSocketFactory ;
103142
104143 private PacketChannel channel ;
105144 private volatile boolean connected ;
@@ -166,6 +205,17 @@ public void setBlocking(boolean blocking) {
166205 this .blocking = blocking ;
167206 }
168207
208+ public SSLMode getSSLMode () {
209+ return sslMode ;
210+ }
211+
212+ public void setSSLMode (SSLMode sslMode ) {
213+ if (sslMode == null ) {
214+ throw new IllegalArgumentException ("SSL mode cannot be NULL" );
215+ }
216+ this .sslMode = sslMode ;
217+ }
218+
169219 /**
170220 * @return server id (65535 by default)
171221 * @see #setServerId(long)
@@ -326,6 +376,13 @@ public void setSocketFactory(SocketFactory socketFactory) {
326376 this .socketFactory = socketFactory ;
327377 }
328378
379+ /**
380+ * @param sslSocketFactory custom ssl socket factory
381+ */
382+ public void setSslSocketFactory (SSLSocketFactory sslSocketFactory ) {
383+ this .sslSocketFactory = sslSocketFactory ;
384+ }
385+
329386 /**
330387 * @param threadFactory custom thread factory. If not provided, threads will be created using simple "new Thread()".
331388 */
@@ -357,7 +414,7 @@ public void connect() throws IOException {
357414 ". Please make sure it's running." , e );
358415 }
359416 greetingPacket = receiveGreeting ();
360- authenticate (greetingPacket . getScramble (), greetingPacket . getServerCollation () );
417+ authenticate (greetingPacket );
361418 if (binlogFilename == null ) {
362419 fetchBinlogFilenameAndPosition ();
363420 }
@@ -446,10 +503,30 @@ private void ensureEventDataDeserializer(EventType eventType,
446503 }
447504 }
448505
449- private void authenticate (String salt , int collation ) throws IOException {
450- AuthenticateCommand authenticateCommand = new AuthenticateCommand (schema , username , password , salt );
506+ private void authenticate (GreetingPacket greetingPacket ) throws IOException {
507+ int collation = greetingPacket .getServerCollation ();
508+ int packetNumber = 1 ;
509+ if (sslMode != SSLMode .DISABLED ) {
510+ boolean serverSupportsSSL = (greetingPacket .getServerCapabilities () & ClientCapabilities .SSL ) != 0 ;
511+ if (!serverSupportsSSL && (sslMode == SSLMode .REQUIRED || sslMode == SSLMode .VERIFY_CA ||
512+ sslMode == SSLMode .VERIFY_IDENTITY )) {
513+ throw new IOException ("MySQL server does not support SSL" );
514+ }
515+ if (serverSupportsSSL ) {
516+ SSLRequestCommand sslRequestCommand = new SSLRequestCommand ();
517+ sslRequestCommand .setCollation (collation );
518+ channel .write (sslRequestCommand , packetNumber ++);
519+ SSLSocketFactory sslSocketFactory = this .sslSocketFactory != null ? this .sslSocketFactory :
520+ sslMode == SSLMode .REQUIRED ? DEFAULT_REQUIRED_SSL_MODE_SOCKET_FACTORY :
521+ DEFAULT_VERIFY_CA_SSL_MODE_SOCKET_FACTORY ;
522+ channel .upgradeToSSL (sslSocketFactory ,
523+ sslMode == SSLMode .VERIFY_IDENTITY ? new TLSHostnameVerifier () : null );
524+ }
525+ }
526+ AuthenticateCommand authenticateCommand = new AuthenticateCommand (schema , username , password ,
527+ greetingPacket .getScramble ());
451528 authenticateCommand .setCollation (collation );
452- channel .write (authenticateCommand );
529+ channel .write (authenticateCommand , packetNumber );
453530 byte [] authenticationResult = channel .read ();
454531 if (authenticationResult [0 ] != (byte ) 0x00 /* ok */ ) {
455532 if (authenticationResult [0 ] == (byte ) 0xFF /* error */ ) {
0 commit comments