Skip to content

Commit bf0387b

Browse files
committed
rfq: allow os, custom certificates
Adds both 'TrustSystemRootCAs' and 'CustomCertificates' to the rfq TLSConfig. The former indicates whether or not to trust the operating system's root CA list; the latter allows additional certificates (CA or self-signed) to be trusted. Also adds a basic unit test skeleton.
1 parent 166bc8c commit bf0387b

File tree

3 files changed

+152
-41
lines changed

3 files changed

+152
-41
lines changed

rfq/oracle.go

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package rfq
22

33
import (
44
"context"
5-
"crypto/tls"
65
"fmt"
76
"math"
87
"net/url"
@@ -16,8 +15,6 @@ import (
1615
"github.com/lightningnetwork/lnd/lnwire"
1716
"github.com/lightningnetwork/lnd/routing/route"
1817
"google.golang.org/grpc"
19-
"google.golang.org/grpc/credentials"
20-
"google.golang.org/grpc/credentials/insecure"
2118
)
2219

2320
// PriceQueryIntent is an enum that represents the intent of a price rate
@@ -186,32 +183,6 @@ type RpcPriceOracle struct {
186183
rawConn *grpc.ClientConn
187184
}
188185

189-
// serverDialOpts returns the set of server options needed to connect to the
190-
// price oracle RPC server using a TLS connection.
191-
func serverDialOpts() ([]grpc.DialOption, error) {
192-
var opts []grpc.DialOption
193-
194-
// Skip TLS certificate verification.
195-
tlsConfig := tls.Config{InsecureSkipVerify: true}
196-
transportCredentials := credentials.NewTLS(&tlsConfig)
197-
opts = append(opts, grpc.WithTransportCredentials(transportCredentials))
198-
199-
return opts, nil
200-
}
201-
202-
// insecureServerDialOpts returns the set of server options needed to connect to
203-
// the price oracle RPC server using a TLS connection.
204-
func insecureServerDialOpts() ([]grpc.DialOption, error) {
205-
var opts []grpc.DialOption
206-
207-
// Skip TLS certificate verification.
208-
opts = append(opts, grpc.WithTransportCredentials(
209-
insecure.NewCredentials(),
210-
))
211-
212-
return opts, nil
213-
}
214-
215186
// NewRpcPriceOracle creates a new RPC price oracle handle given the address
216187
// of the price oracle RPC server.
217188
func NewRpcPriceOracle(addrStr string, tlsConfig *TLSConfig) (*RpcPriceOracle,
@@ -222,27 +193,21 @@ func NewRpcPriceOracle(addrStr string, tlsConfig *TLSConfig) (*RpcPriceOracle,
222193
return nil, err
223194
}
224195

225-
// Connect to the RPC server.
226-
dialOpts, err := serverDialOpts()
196+
// Create transport credentials and dial options from the supplied TLS
197+
// config.
198+
transportCredentials, err := configureTransportCredentials(tlsConfig)
227199
if err != nil {
228200
return nil, err
229201
}
230202

231-
// Determine whether we should skip certificate verification.
232-
dialInsecure := tlsConfig.InsecureSkipVerify
233-
234-
// Allow connecting to a non-TLS (h2c, http over cleartext) gRPC server,
235-
// should be used for testing only.
236-
if dialInsecure {
237-
dialOpts, err = insecureServerDialOpts()
238-
if err != nil {
239-
return nil, err
240-
}
203+
dialOpts := []grpc.DialOption{
204+
grpc.WithTransportCredentials(transportCredentials),
241205
}
242206

243207
// Formulate the server address dial string.
244208
serverAddr := fmt.Sprintf("%s:%s", addr.Hostname(), addr.Port())
245209

210+
// Connect to the RPC server.
246211
conn, err := grpc.Dial(serverAddr, dialOpts...)
247212
if err != nil {
248213
return nil, err

rfq/tls.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
package rfq
22

3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
7+
"google.golang.org/grpc/credentials"
8+
"google.golang.org/grpc/credentials/insecure"
9+
)
10+
311
// TLSConfig represents TLS configuration options for oracle connections.
412
type TLSConfig struct {
13+
// Enabled indicates that we should use TLS.
14+
Enabled bool
15+
516
// InsecureSkipVerify disables certificate verification.
617
InsecureSkipVerify bool
18+
19+
// TrustSystemRootCAs indicates whether or not to use the operating
20+
// system's root certificate authority list.
21+
TrustSystemRootCAs bool
22+
23+
// CustomCertificates contains PEM data for additional root CA and
24+
// self-signed certificates to trust.
25+
CustomCertificates []byte
726
}
827

928
// DefaultTLSConfig returns a default TLS configuration.
@@ -12,3 +31,45 @@ func DefaultTLSConfig() *TLSConfig {
1231
InsecureSkipVerify: true,
1332
}
1433
}
34+
35+
// configureTransportCredentials configures the TLS transport credentials to
36+
// be used for RPC connections.
37+
func configureTransportCredentials(
38+
config *TLSConfig) (credentials.TransportCredentials, error) {
39+
40+
// If TLS is disabled, return insecure credentials.
41+
if !config.Enabled {
42+
return insecure.NewCredentials(), nil
43+
}
44+
45+
// If we're to skip certificate verification, then return TLS
46+
// credentials with certificate verification disabled.
47+
if config.InsecureSkipVerify {
48+
creds := credentials.NewTLS(&tls.Config{
49+
InsecureSkipVerify: true,
50+
})
51+
return creds, nil
52+
}
53+
54+
// Initialize the certificate pool.
55+
certPool, err := constructCertPool(config.TrustSystemRootCAs)
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
// If we have any custom certificates, add them to the certificate
61+
// pool.
62+
certPool.AppendCertsFromPEM(config.CustomCertificates)
63+
64+
// Return the constructed transport credentials.
65+
return credentials.NewClientTLSFromCert(certPool, ""), nil
66+
}
67+
68+
// constructCertPool is a helper for constructing an initial certificate pool,
69+
// depending on whether or not we should trust the system root CA list.
70+
func constructCertPool(trustSystem bool) (*x509.CertPool, error) {
71+
if trustSystem {
72+
return x509.SystemCertPool()
73+
}
74+
return x509.NewCertPool(), nil
75+
}

rfq/tls_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package rfq
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
// Test certificate data - a valid self-signed certificate for testing
10+
const validTestCertPEM = `-----BEGIN CERTIFICATE-----
11+
MIICmjCCAYICCQCuu1gzY+BBKjANBgkqhkiG9w0BAQsFADAPMQ0wCwYDVQQDDAR0
12+
ZXN0MB4XDTI1MDgyODEwNDA1NVoXDTI1MDgyOTEwNDA1NVowDzENMAsGA1UEAwwE
13+
dGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALTWCm8l3d9nE2QK
14+
TK8HJ36ftO8pK3//nb8Nj/p97FrPFSgzdgL1ZNJs4gP5/ZsU+iE6VeKhalHoSf6/
15+
IMLe3ATTL0rWA1M6z7cw6ll8VS8NQMaMSFWNomncsxyoJAQde++SC5f1RwQJBD/0
16+
gGB4bJIIqUHtT12m23GLX48d6JGEEi5kEQtk91S/QGnHtglzZ8CQOogDBzDhSHu2
17+
jj4mKYDgkXcyAqN7DoDzoEcrpeAaeAwem8k1sFBeTtrqT1ot7Ey5KG+RUyJbdKGt
18+
5adJiwH782NgsSnISQ2X7Sct6Uu0JzHKx9JzyABsA05tf3cNJkLhh1Is9edYI2e9
19+
m0dqedECAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAQOCs/7xZVPjabbhdv30mUJMG
20+
lddi2A+R/5IRXW1MKnpemwiv4ZWYQ9PMTmuR7kqaF7AGLkvx+5sp2evUJN4x7vHP
21+
ao6wihbdh+vBkrobE+Y9dE7nbkvMQSNi1sXzDnfZB9LqY9Huun2soUwBQNCMPVMa
22+
Wo7g6udwyA48doEVJMjThFLPcW7xmsy6Ldew682m1kD8/ag+9qihX1IJyiqiEjha
23+
3uT4CT+zEg0RJorEJKbR38fE4Uhx1wZO4zvjEg6qZeW/I4lw+UzSY5xV7lJ1EQvf
24+
BcoNuBHB65RxQM5fpA7hkEFm1bxBoowGX2hx6VCCeBBwREISRfgvkUxZahUXNg==
25+
-----END CERTIFICATE-----`
26+
27+
// Invalid PEM data for testing failure cases
28+
const invalidTestCertPEM = `-----BEGIN CERTIFICATE-----
29+
This is not a valid certificate
30+
-----END CERTIFICATE-----`
31+
32+
// DefaultTLSConfig returns a default TLS configuration for testing.
33+
func DefaultTLSConfig() *TLSConfig {
34+
return &TLSConfig{
35+
InsecureSkipVerify: true,
36+
}
37+
}
38+
39+
// TestConfigureTransportCredentials_InsecureSkipVerify tests the function
40+
// when InsecureSkipVerify is true.
41+
func TestConfigureTransportCredentials_InsecureSkipVerify(t *testing.T) {
42+
config := &TLSConfig{
43+
InsecureSkipVerify: true,
44+
}
45+
46+
creds, err := configureTransportCredentials(config)
47+
48+
require.NoError(t, err)
49+
require.NotNil(t, creds)
50+
51+
// Verify that we got insecure credentials by checking the type
52+
require.Equal(t, "insecure", creds.Info().SecurityProtocol)
53+
}
54+
55+
// TestConfigureTransportCredentials_ValidCustomCertificates tests the
56+
// function when valid custom certificates are provided.
57+
func TestConfigureTransportCredentials_ValidCustomCertificates(t *testing.T) {
58+
config := &TLSConfig{
59+
InsecureSkipVerify: false,
60+
CustomCertificates: []byte(validTestCertPEM),
61+
}
62+
63+
creds, err := configureTransportCredentials(config)
64+
65+
require.NoError(t, err)
66+
require.NotNil(t, creds)
67+
68+
// Verify that we got TLS credentials (not insecure)
69+
require.Equal(t, "tls", creds.Info().SecurityProtocol)
70+
}
71+
72+
// TestConfigureTransportCredentials_NoCredentialsConfigured tests the
73+
// function when no credentials are configured.
74+
func TestConfigureTransportCredentials_NoCredentialsConfigured(t *testing.T) {
75+
config := &TLSConfig{
76+
InsecureSkipVerify: false,
77+
CustomCertificates: nil,
78+
}
79+
80+
creds, err := configureTransportCredentials(config)
81+
82+
require.NoError(t, err)
83+
require.NotNil(t, creds)
84+
require.Equal(t, "tls", creds.Info().SecurityProtocol)
85+
}

0 commit comments

Comments
 (0)