Skip to content

Commit 219f6e0

Browse files
committed
Add service package
To help facilitate new features, begin moving the main webhook service properties to a Service struct.
1 parent dff684a commit 219f6e0

File tree

6 files changed

+370
-127
lines changed

6 files changed

+370
-127
lines changed

internal/service/security/tls.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// Package security provides HTTP security management help to the webhook
2+
// service.
3+
package security
4+
5+
import (
6+
"crypto/tls"
7+
"fmt"
8+
"io"
9+
"log"
10+
"strings"
11+
"sync"
12+
)
13+
14+
// KeyPairReloader contains the active TLS certificate. It can be used with
15+
// the tls.Config.GetCertificate property to support live updating of the
16+
// certificate.
17+
type KeyPairReloader struct {
18+
certMu sync.RWMutex
19+
cert *tls.Certificate
20+
certPath string
21+
keyPath string
22+
}
23+
24+
// NewKeyPairReloader creates a new KeyPairReloader given the certificate and
25+
// key path.
26+
func NewKeyPairReloader(certPath, keyPath string) (*KeyPairReloader, error) {
27+
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
res := &KeyPairReloader{
33+
cert: &cert,
34+
certPath: certPath,
35+
keyPath: keyPath,
36+
}
37+
38+
return res, nil
39+
}
40+
41+
// GetCertificateFunc provides a function for tls.Config.GetCertificate.
42+
func (kpr *KeyPairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
43+
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
44+
kpr.certMu.RLock()
45+
defer kpr.certMu.RUnlock()
46+
return kpr.cert, nil
47+
}
48+
}
49+
50+
// WriteTLSSupportedCipherStrings writes a list of ciphers to w. The list is
51+
// all supported TLS ciphers based upon min.
52+
func WriteTLSSupportedCipherStrings(w io.Writer, min string) error {
53+
m, err := GetTLSVersion(min)
54+
if err != nil {
55+
return err
56+
}
57+
58+
for _, c := range tls.CipherSuites() {
59+
var found bool
60+
61+
for _, v := range c.SupportedVersions {
62+
if v >= m {
63+
found = true
64+
}
65+
}
66+
67+
if !found {
68+
continue
69+
}
70+
71+
_, err := w.Write([]byte(c.Name + "\n"))
72+
if err != nil {
73+
return err
74+
}
75+
}
76+
77+
return nil
78+
}
79+
80+
// GetTLSVersion converts a TLS version string, v, (e.g. "v1.3") into a TLS
81+
// version ID.
82+
func GetTLSVersion(v string) (uint16, error) {
83+
switch v {
84+
case "1.3", "v1.3", "tls1.3":
85+
return tls.VersionTLS13, nil
86+
case "1.2", "v1.2", "tls1.2", "":
87+
return tls.VersionTLS12, nil
88+
case "1.1", "v1.1", "tls1.1":
89+
return tls.VersionTLS11, nil
90+
case "1.0", "v1.0", "tls1.0":
91+
return tls.VersionTLS10, nil
92+
default:
93+
return 0, fmt.Errorf("error: unknown TLS version: %s", v)
94+
}
95+
}
96+
97+
// GetTLSCipherSuites converts a comma separated list of cipher suites into a
98+
// slice of TLS cipher suite IDs.
99+
func GetTLSCipherSuites(v string) []uint16 {
100+
supported := tls.CipherSuites()
101+
102+
if v == "" {
103+
suites := make([]uint16, len(supported))
104+
105+
for _, cs := range supported {
106+
suites = append(suites, cs.ID)
107+
}
108+
109+
return suites
110+
}
111+
112+
var found bool
113+
txts := strings.Split(v, ",")
114+
suites := make([]uint16, len(txts))
115+
116+
for _, want := range txts {
117+
found = false
118+
119+
for _, cs := range supported {
120+
if want == cs.Name {
121+
suites = append(suites, cs.ID)
122+
found = true
123+
}
124+
}
125+
126+
if !found {
127+
log.Fatalln("error: unknown TLS cipher suite:", want)
128+
}
129+
}
130+
131+
return suites
132+
}
133+
134+
// GetTLSCurves converts a comma separated list of curves into a
135+
// slice of TLS curve IDs.
136+
func GetTLSCurves(v string) []tls.CurveID {
137+
supported := []tls.CurveID{
138+
tls.CurveP256,
139+
tls.CurveP384,
140+
tls.CurveP521,
141+
tls.X25519,
142+
}
143+
144+
if v == "" {
145+
return supported
146+
}
147+
148+
var found bool
149+
txts := strings.Split(v, ",")
150+
res := make([]tls.CurveID, len(txts))
151+
152+
for _, want := range txts {
153+
found = false
154+
155+
for _, c := range supported {
156+
if want == c.String() {
157+
res = append(res, c)
158+
found = true
159+
}
160+
}
161+
162+
if !found {
163+
log.Fatalln("error: unknown TLS curve:", want)
164+
}
165+
}
166+
167+
return res
168+
}

internal/service/service.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Package service manages the webhook HTTP service.
2+
package service
3+
4+
import (
5+
"crypto/tls"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
10+
"github.com/adnanh/webhook/internal/pidfile"
11+
"github.com/adnanh/webhook/internal/service/security"
12+
13+
"github.com/gorilla/mux"
14+
)
15+
16+
// Service is the webhook HTTP service.
17+
type Service struct {
18+
// Address is the listener address for the service (e.g. "127.0.0.1:9000")
19+
Address string
20+
21+
// TLS settings
22+
enableTLS bool
23+
tlsCiphers []uint16
24+
tlsMinVersion uint16
25+
kpr *security.KeyPairReloader
26+
27+
// Future TLS settings to consider:
28+
// - tlsMaxVersion
29+
// - configurable TLS curves
30+
// - modern and intermediate helpers that follows Mozilla guidelines
31+
// - ca root and intermediate certs
32+
33+
listener net.Listener
34+
server *http.Server
35+
36+
pidFile *pidfile.PIDFile
37+
38+
// Hooks map[string]hook.Hooks
39+
}
40+
41+
// New creates a new webhook HTTP service for the given address and port.
42+
func New(ip string, port int) *Service {
43+
return &Service{
44+
Address: fmt.Sprintf("%s:%d", ip, port),
45+
server: &http.Server{},
46+
tlsMinVersion: tls.VersionTLS12,
47+
}
48+
}
49+
50+
// Listen announces the TCP service on the local network.
51+
//
52+
// To enable TLS, ensure that SetTLSEnabled is called prior to Listen.
53+
//
54+
// After calling Listen, Serve must be called to begin serving HTTP requests.
55+
// The steps are separated so that we can drop privileges, if necessary, after
56+
// opening the listening port.
57+
func (s *Service) Listen() error {
58+
ln, err := net.Listen("tcp", s.Address)
59+
if err != nil {
60+
return err
61+
}
62+
63+
if !s.enableTLS {
64+
s.listener = ln
65+
return nil
66+
}
67+
68+
if s.kpr == nil {
69+
panic("Listen called with TLS enabled but KPR is nil")
70+
}
71+
72+
c := &tls.Config{
73+
GetCertificate: s.kpr.GetCertificateFunc(),
74+
CipherSuites: s.tlsCiphers,
75+
CurvePreferences: security.GetTLSCurves(""),
76+
MinVersion: s.tlsMinVersion,
77+
PreferServerCipherSuites: true,
78+
}
79+
80+
s.listener = tls.NewListener(ln, c)
81+
82+
return nil
83+
}
84+
85+
// Serve begins accepting incoming HTTP connections.
86+
func (s *Service) Serve() error {
87+
s.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
88+
89+
if s.listener == nil {
90+
err := s.Listen()
91+
if err != nil {
92+
return err
93+
}
94+
}
95+
96+
defer s.listener.Close()
97+
return s.server.Serve(s.listener)
98+
}
99+
100+
// SetHTTPHandler sets the underly HTTP server Handler.
101+
func (s *Service) SetHTTPHandler(r *mux.Router) {
102+
s.server.Handler = r
103+
}
104+
105+
// SetTLSCiphers sets the supported TLS ciphers.
106+
func (s *Service) SetTLSCiphers(suites string) {
107+
s.tlsCiphers = security.GetTLSCipherSuites(suites)
108+
}
109+
110+
// SetTLSEnabled enables TLS for the service. Must be called prior to Listen.
111+
func (s *Service) SetTLSEnabled() {
112+
s.enableTLS = true
113+
}
114+
115+
// SetTLSKeyPair sets the TLS key pair for the service.
116+
func (s *Service) SetTLSKeyPair(certPath, keyPath string) error {
117+
if certPath == "" {
118+
return fmt.Errorf("error: certificate required for TLS")
119+
}
120+
121+
if keyPath == "" {
122+
return fmt.Errorf("error: key required for TLS")
123+
}
124+
125+
var err error
126+
127+
s.kpr, err = security.NewKeyPairReloader(certPath, keyPath)
128+
if err != nil {
129+
return err
130+
}
131+
132+
return nil
133+
}
134+
135+
// SetTLSMinVersion sets the minimum support TLS version, such as "v1.3".
136+
func (s *Service) SetTLSMinVersion(ver string) (err error) {
137+
s.tlsMinVersion, err = security.GetTLSVersion(ver)
138+
return err
139+
}
140+
141+
// CreatePIDFile creates a new PID file at path p.
142+
func (s *Service) CreatePIDFile(p string) (err error) {
143+
s.pidFile, err = pidfile.New(p)
144+
return err
145+
}
146+
147+
// DeletePIDFile deletes a previously created PID file.
148+
func (s *Service) DeletePIDFile() error {
149+
if s.pidFile != nil {
150+
return s.pidFile.Remove()
151+
}
152+
return nil
153+
}

signals.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import (
77
"os"
88
"os/signal"
99
"syscall"
10+
11+
"github.com/adnanh/webhook/internal/service"
1012
)
1113

12-
func setupSignals() {
14+
func setupSignals(svc *service.Service) {
1315
log.Printf("setting up os signal watcher\n")
1416

1517
signals = make(chan os.Signal, 1)
@@ -18,10 +20,10 @@ func setupSignals() {
1820
signal.Notify(signals, syscall.SIGTERM)
1921
signal.Notify(signals, os.Interrupt)
2022

21-
go watchForSignals()
23+
go watchForSignals(svc)
2224
}
2325

24-
func watchForSignals() {
26+
func watchForSignals(svc *service.Service) {
2527
log.Println("os signal watcher ready")
2628

2729
for {
@@ -37,11 +39,9 @@ func watchForSignals() {
3739

3840
case os.Interrupt, syscall.SIGTERM:
3941
log.Printf("caught %s signal; exiting\n", sig)
40-
if pidFile != nil {
41-
err := pidFile.Remove()
42-
if err != nil {
43-
log.Print(err)
44-
}
42+
err := svc.DeletePIDFile()
43+
if err != nil {
44+
log.Print(err)
4545
}
4646
os.Exit(0)
4747

signals_windows.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
package main
44

5-
func setupSignals() {
5+
import "github.com/adnanh/webhook/internal/service"
6+
7+
func setupSignals(_ *service.Service) {
68
// NOOP: Windows doesn't have signals equivalent to the Unix world.
79
}

0 commit comments

Comments
 (0)