@@ -2,27 +2,99 @@ package utils
22
33import (
44 "archive/tar"
5+ "archive/zip"
56 "compress/gzip"
67 "fmt"
78 "io"
9+ "log"
810 "log/slog"
11+ "net/http"
912 "os"
1013 "path/filepath"
1114 "strings"
1215)
1316
14- // ExtractArchiveFile extracts a .tar.gz / .tgz file located at archivePath,
15- // using outputDir as the root of the extracted files.
17+ // ExtractArchiveFile extracts a .tar.gz / .tgz file or .zip file located
18+ // at archivePath, using outputDir as the root of the extracted files.
1619func ExtractArchiveFile (archivePath string , outputDir string ) error {
20+ if outputDir == "" {
21+ return fmt .Errorf ("outputDir is empty" )
22+ }
23+
1724 f , err := os .Open (archivePath )
1825 if err != nil {
1926 return err
2027 }
2128 defer f .Close ()
2229
23- return processGzipFile (f , func (reader io.Reader ) error {
24- return extractTar (reader , outputDir )
25- })
30+ fileType , err := detectFileType (f )
31+ if err != nil {
32+ return err
33+ }
34+
35+ if fileType == "application/zip" {
36+ return processZipFile (archivePath , outputDir )
37+ } else if fileType == "application/x-gzip" {
38+ return processGzipFile (f , func (reader io.Reader ) error {
39+ return extractTar (reader , outputDir )
40+ })
41+ } else {
42+ return fmt .Errorf ("%s is not a supported archive file type: %s" , archivePath , fileType )
43+ }
44+ }
45+
// processZipFile extracts every entry of the zip archive at filePath into
// outputDir, creating outputDir and any missing intermediate directories.
//
// An entry whose cleaned destination would land outside outputDir (ZipSlip,
// https://snyk.io/research/zip-slip-vulnerability) aborts the extraction with
// an error. Entries extracted before a failure are left on disk.
func processZipFile(filePath, outputDir string) error {
	if err := os.MkdirAll(outputDir, os.ModePerm); err != nil {
		return fmt.Errorf("create dir for %s failed: %w", outputDir, err)
	}

	zipReader, err := zip.OpenReader(filePath)
	if err != nil {
		return fmt.Errorf("open zip archive %s: %w", filePath, err)
	}
	defer func() {
		// A failed Close on the read-only archive does not affect the files
		// already written, so it is reported instead of returned.
		if cerr := zipReader.Close(); cerr != nil {
			log.Printf("closing zip archive %s: %v", filePath, cerr)
		}
	}()

	root := filepath.Clean(outputDir)
	for _, file := range zipReader.File {
		outputPath := filepath.Join(root, file.Name)

		// ZipSlip guard: the cleaned destination must stay inside root.
		// Comparing via Rel (instead of a string prefix) also works when the
		// root is "." or the filesystem root.
		rel, err := filepath.Rel(root, outputPath)
		if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
			// Note: this error string is used in a test
			return fmt.Errorf("archive path escapes output dir: %s", file.Name)
		}

		if file.FileInfo().IsDir() {
			if err := os.MkdirAll(outputPath, 0o755); err != nil {
				return err
			}
			continue
		}

		// Ensure parent directories exist before creating the file
		if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
			return err
		}

		if err := extractZipEntry(file, outputPath); err != nil {
			return err
		}
	}

	return nil
}

// extractZipEntry writes the content of one non-directory zip entry to
// outputPath. It lives in its own function so that both file handles are
// closed as soon as the entry is done rather than when the whole archive is.
func extractZipEntry(file *zip.File, outputPath string) (err error) {
	srcFile, err := file.Open()
	if err != nil {
		return err
	}
	defer srcFile.Close()

	destFile, err := os.Create(outputPath)
	if err != nil {
		return err
	}
	defer func() {
		// Close can surface a deferred write failure; do not lose it.
		if cerr := destFile.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(destFile, srcFile)
	return err
}
2799
28100func processGzipFile (gzFile * os.File , process func (io.Reader ) error ) error {
@@ -48,10 +120,6 @@ extractTar extracts the contents of the given stream of bytes of a tar archive,
48120outputDir as the root of the extracted files.
49121*/
50122func extractTar (tarStream io.Reader , outputDir string ) error {
51- if outputDir == "" {
52- return fmt .Errorf ("outputDir is empty" )
53- }
54-
55123 tarReader := tar .NewReader (tarStream )
56124
57125 var header * tar.Header
@@ -107,3 +175,16 @@ func extractTar(tarStream io.Reader, outputDir string) error {
107175
108176 return nil
109177}
178+
// detectFileType sniffs the MIME type of archiveFile from its leading bytes
// using http.DetectContentType and returns it (for example "application/zip"
// or "application/x-gzip").
//
// It reads from the file's current offset and does not rewind, so a caller
// that wants to read the content afterwards has to seek back itself. An empty
// file yields io.EOF; any other read failure is returned as-is.
func detectFileType(archiveFile *os.File) (string, error) {
	// DetectContentType never uses more than the first 512 bytes.
	buffer := make([]byte, 512)

	// ReadFull keeps reading until the buffer is full or the file ends; a
	// file shorter than the buffer is expected and not an error.
	n, err := io.ReadFull(archiveFile, buffer)
	if err != nil && err != io.ErrUnexpectedEOF {
		return "", err
	}

	// Only sniff the bytes actually read: the zero padding of a short read
	// would otherwise be taken for binary content.
	return http.DetectContentType(buffer[:n]), nil
}
0 commit comments