11package main
22
33import (
4+ "context"
45 "flag"
56 "fmt"
67 "os"
8+ "os/signal"
79 "strings"
10+ "sync"
11+ "syscall"
12+ "time"
813
914 "file-encryptor/pkg/crypto"
1015 "file-encryptor/pkg/fileops"
@@ -26,6 +31,23 @@ func main() {
2631 logger := logging .NewLogger ()
2732 logger .LogDebug ("Starting file encryptor" )
2833
34+ // Create a context that can be cancelled
35+ ctx , cancel := context .WithCancel (context .Background ())
36+ defer cancel ()
37+
38+ // Set up signal handling for graceful shutdown
39+ signalChan := make (chan os.Signal , 1 )
40+ signal .Notify (signalChan , os .Interrupt , syscall .SIGTERM )
41+
42+ // Handle cancellation in a separate goroutine
43+ go func () {
44+ select {
45+ case <- signalChan :
46+ logger .LogInfo ("Received termination signal, shutting down gracefully..." )
47+ cancel ()
48+ }
49+ }()
50+
2951 encrypt := flag .Bool ("e" , false , "Encrypt the file" )
3052 decrypt := flag .Bool ("d" , false , "Decrypt the file" )
3153
@@ -43,8 +65,18 @@ func main() {
4365
4466 generateKeys := flag .Bool ("generate-keys" , false , "Generate a new RSA key pair" )
4567 keyBaseName := flag .String ("key-name" , "key" , "Base name for the generated key files" )
68+
69+ timeout := flag .Duration ("timeout" , 30 * time .Minute , "Timeout for the entire operation" )
4670
4771 flag .Parse ()
72+
73+ // Set a timeout for the entire operation
74+ if * timeout > 0 {
75+ ctx , cancel = context .WithTimeout (ctx , * timeout )
76+ defer cancel ()
77+ logger .LogDebug (fmt .Sprintf ("Operation will timeout after %v" , * timeout ))
78+ }
79+
4880 // Add remaining arguments as files only if they don't start with "-"
4981 // This prevents treating flags like "-k" as files
5082 remainingArgs := flag .Args ()
@@ -62,7 +94,7 @@ func main() {
6294 logger .LogDebug (fmt .Sprintf ("Generate Keys: %v, Key Base Name: %s" , * generateKeys , * keyBaseName ))
6395
6496 if * generateKeys && ! * encrypt && len (files ) == 0 {
65- if err := handleGenerateKeys (* keyBaseName , logger ); err != nil {
97+ if err := handleGenerateKeys (ctx , * keyBaseName , logger ); err != nil {
6698 logger .LogError (err .Error ())
6799 os .Exit (1 )
68100 }
@@ -82,13 +114,19 @@ func main() {
82114
83115 if * generateKeys && * encrypt && len (files ) > 0 {
84116 operation = "Encryption with key generation"
85- outputFiles , err = handleGenerateAndEncrypt (* keyBaseName , files , logger )
117+ outputFiles , err = handleGenerateAndEncrypt (ctx , * keyBaseName , files , logger )
86118 } else if * encrypt {
87119 operation = "Encryption"
88- outputFiles , err = handleEncryption (files , key , password , logger )
120+ outputFiles , err = handleEncryption (ctx , files , key , password , logger )
89121 } else {
90122 operation = "Decryption"
91- outputFiles , err = handleDecryption (files , key , password , logger )
123+ outputFiles , err = handleDecryption (ctx , files , key , password , logger )
124+ }
125+
126+ // Check if the operation was cancelled
127+ if ctx .Err () != nil {
128+ logger .LogError (fmt .Sprintf ("Operation cancelled: %v" , ctx .Err ()))
129+ os .Exit (1 )
92130 }
93131
94132 if err != nil {
@@ -139,9 +177,16 @@ func validateFlags(encrypt, decrypt bool, files []string, key, password string,
139177 return nil
140178}
141179
142- func handleGenerateKeys (keyBaseName string , logger * logging.Logger ) error {
180+ func handleGenerateKeys (ctx context. Context , keyBaseName string , logger * logging.Logger ) error {
143181 logger .LogInfo ("Starting RSA key pair generation" )
144182
183+ // Check for cancellation
184+ select {
185+ case <- ctx .Done ():
186+ return ctx .Err ()
187+ default :
188+ }
189+
145190 if err := crypto .GenerateRSAKeyPair (keyBaseName , logger ); err != nil {
146191 return fmt .Errorf ("failed to generate RSA key pair: %w" , err )
147192 }
@@ -150,9 +195,16 @@ func handleGenerateKeys(keyBaseName string, logger *logging.Logger) error {
150195 return nil
151196}
152197
153- func handleGenerateAndEncrypt (keyBaseName string , files []string , logger * logging.Logger ) ([]string , error ) {
198+ func handleGenerateAndEncrypt (ctx context. Context , keyBaseName string , files []string , logger * logging.Logger ) ([]string , error ) {
154199 logger .LogInfo ("Starting RSA key pair generation and file encryption" )
155200
201+ // Check for cancellation
202+ select {
203+ case <- ctx .Done ():
204+ return nil , ctx .Err ()
205+ default :
206+ }
207+
156208 privateKeyName , publicKeyName , err := crypto .GenerateRSAKeyPairWithNames (keyBaseName , logger )
157209 if err != nil {
158210 return nil , fmt .Errorf ("failed to generate RSA key pair: %w" , err )
@@ -163,17 +215,10 @@ func handleGenerateAndEncrypt(keyBaseName string, files []string, logger *loggin
163215 return nil , fmt .Errorf ("failed to create encryptor: %w" , err )
164216 }
165217
166- outputFiles := make ([]string , 0 , len (files ))
167- for _ , file := range files {
168- if hash , err := crypto .CalculateFileHash (file ); err == nil {
169- logger .LogDebug (fmt .Sprintf ("Original file hash for %s: %s" , file , hash ))
170- }
171-
172- outputFile := file + ".enc"
173- if err := fileops .EncryptFile (file , outputFile , encryptor , logger ); err != nil {
174- return outputFiles , fmt .Errorf ("failed to encrypt %s: %w" , file , err )
175- }
176- outputFiles = append (outputFiles , outputFile )
218+ // Use the process files function with the context
219+ outputFiles , err := processFiles (ctx , files , true , encryptor , logger )
220+ if err != nil {
221+ return outputFiles , err
177222 }
178223
179224 logger .LogInfo (fmt .Sprintf ("Private key saved to: %s" , privateKeyName ))
@@ -212,7 +257,9 @@ func initializeCrypto(isEncryption bool, key, password string, logger *logging.L
212257}
213258
214259// processFiles handles the common file processing logic for both encryption and decryption
260+ // with support for concurrent processing of multiple files
215261func processFiles (
262+ ctx context.Context ,
216263 files []string ,
217264 isEncryption bool ,
218265 cryptoProcessor interface {},
@@ -221,9 +268,37 @@ func processFiles(
221268 operation := map [bool ]string {true : "encryption" , false : "decryption" }[isEncryption ]
222269 outputFiles := make ([]string , 0 , len (files ))
223270
271+ // Check for cancellation
272+ select {
273+ case <- ctx .Done ():
274+ return nil , ctx .Err ()
275+ default :
276+ }
277+
224278 logger .LogInfo (fmt .Sprintf ("Found %d files to process" , len (files )))
225279
226- for _ , file := range files {
280+ // Use a worker pool for concurrent processing
281+ // For small numbers of files, this might be overkill,
282+ // but for many files it will provide performance benefits
283+ numWorkers := 3 // Limit concurrency to avoid system overload
284+ if len (files ) < numWorkers {
285+ numWorkers = len (files )
286+ }
287+
288+ // Channel for job results
289+ type result struct {
290+ outputFile string
291+ err error
292+ }
293+ resultChan := make (chan result )
294+
295+ // Process files concurrently using a worker pool
296+ var wg sync.WaitGroup
297+
298+ // Function to process a single file
299+ processFile := func (file string ) {
300+ defer wg .Done ()
301+
227302 var outputFile string
228303 var processErr error
229304
@@ -242,37 +317,80 @@ func processFiles(
242317 processErr = fileops .DecryptFile (file , outputFile , cryptoProcessor .(crypto.Decryptor ), logger )
243318 }
244319
245- if processErr != nil {
246- if ! isEncryption && strings .Contains (processErr .Error (), "file integrity check failed" ) {
247- return outputFiles , fmt .Errorf ("security error while decrypting %s: %v" , file , processErr )
320+ // Send result through channel
321+ resultChan <- result {outputFile , processErr }
322+ }
323+
324+ // Limit the number of concurrent goroutines with a semaphore
325+ semaphore := make (chan struct {}, numWorkers )
326+
327+ // Start all file processing goroutines
328+ for _ , file := range files {
329+ wg .Add (1 )
330+
331+ // Acquire semaphore slot
332+ semaphore <- struct {}{}
333+
334+ go func (file string ) {
335+ processFile (file )
336+ // Release semaphore slot when done
337+ <- semaphore
338+ }(file )
339+ }
340+
341+ // Close the results channel when all workers are done
342+ go func () {
343+ wg .Wait ()
344+ close (resultChan )
345+ }()
346+
347+ // Collect results as they come in
348+ for res := range resultChan {
349+ if res .err != nil {
350+ // Let any running goroutines finish, but we'll return the error
351+ if ! isEncryption && strings .Contains (res .err .Error (), "file integrity check failed" ) {
352+ return outputFiles , fmt .Errorf ("security error while decrypting: %v" , res .err )
248353 }
249- return outputFiles , fmt .Errorf ("failed to %s %s : %w" , operation , file , processErr )
354+ return outputFiles , fmt .Errorf ("failed to %s: %w" , operation , res . err )
250355 }
251-
252- outputFiles = append (outputFiles , outputFile )
356+ outputFiles = append (outputFiles , res .outputFile )
253357 }
254358
255359 return outputFiles , nil
256360}
257361
258- func handleEncryption (files []string , key , password string , logger * logging.Logger ) ([]string , error ) {
362+ func handleEncryption (ctx context. Context , files []string , key , password string , logger * logging.Logger ) ([]string , error ) {
259363 logger .LogInfo ("Starting file encryption" )
260364
365+ // Check for cancellation
366+ select {
367+ case <- ctx .Done ():
368+ return nil , ctx .Err ()
369+ default :
370+ }
371+
261372 encryptor , err := initializeCrypto (true , key , password , logger )
262373 if err != nil {
263374 return nil , err
264375 }
265376
266- return processFiles (files , true , encryptor , logger )
377+ return processFiles (ctx , files , true , encryptor , logger )
267378}
268379
269- func handleDecryption (files []string , key , password string , logger * logging.Logger ) ([]string , error ) {
380+ func handleDecryption (ctx context. Context , files []string , key , password string , logger * logging.Logger ) ([]string , error ) {
270381 logger .LogInfo ("Starting file decryption" )
271382
383+ // Check for cancellation
384+ select {
385+ case <- ctx .Done ():
386+ return nil , ctx .Err ()
387+ default :
388+ }
389+
272390 decryptor , err := initializeCrypto (false , key , password , logger )
273391 if err != nil {
274392 return nil , err
275393 }
276394
277- return processFiles (files , false , decryptor , logger )
395+ return processFiles (ctx , files , false , decryptor , logger )
278396}
0 commit comments