diff --git a/integration/integration_test.go b/integration/integration_test.go index eaa24c0..96baa31 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -365,7 +365,13 @@ func TestMain(t *testing.T) { io.WriteString(w, fmt.Sprintf("invalid form files: %+v", form.File)) return } - file, err := form.File["file1"][0].Open() + header := form.File["file1"][0] + if ct := header.Header.Get("Content-Type"); ct != "image/jpeg" { + w.WriteHeader(400) + io.WriteString(w, fmt.Sprintf("invalid content-type: %q", ct)) + return + } + file, err := header.Open() if err != nil { w.WriteHeader(400) io.WriteString(w, "cannot open form file: "+err.Error()) @@ -374,7 +380,7 @@ func TestMain(t *testing.T) { var buf bytes.Buffer buf.ReadFrom(file) - if buf.String() != "file content" { + if buf.String() != "\xFF\xD8\xFF" { w.WriteHeader(400) io.WriteString(w, "invalid file content: "+buf.String()) return @@ -382,7 +388,7 @@ func TestMain(t *testing.T) { }) defer server.Close() - tempFile := createTempFile(t, "file content") + tempFile := createTempFile(t, "\xFF\xD8\xFF") // JPEG signature. defer os.Remove(tempFile) res := runFetch(t, fetchPath, server.URL, "-F", "key1=val1", "-F", "file1=@"+tempFile) assertExitCode(t, 0, res) diff --git a/internal/multipart/multipart.go b/internal/multipart/multipart.go index 6bd233f..5751336 100644 --- a/internal/multipart/multipart.go +++ b/internal/multipart/multipart.go @@ -1,9 +1,13 @@ package multipart import ( + "bytes" "io" "mime/multipart" + "net/http" + "net/textproto" "os" + "path/filepath" "strings" "github.com/ryanfowler/fetch/internal/core" @@ -41,20 +45,7 @@ func NewMultipart(kvs []core.KeyVal) *Multipart { } // Form part is a file. - w, err := mpw.CreateFormFile(kv.Key, kv.Val[1:]) - if err != nil { - writer.CloseWithError(err) - return - } - - f, err := os.Open(kv.Val[1:]) - if err != nil { - writer.CloseWithError(err) - return - } - - _, err = io.Copy(w, f) - f.Close() + err := writeFilePart(mpw, kv.Key, kv.Val[1:]) if err != nil { writer.CloseWithError(err) return @@ -72,3 +63,212 @@ func NewMultipart(kvs []core.KeyVal) *Multipart { func (m *Multipart) ContentType() string { return m.contentType } + +// writes the multipart file part and returns any error encountered. +func writeFilePart(mpw *multipart.Writer, key, filename string) error { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + + var r io.Reader = f + ct := detectTypeByExtension(filename) + if ct == "" { + // Unable to detect MIME type by extension, try from raw bytes. + sniff := make([]byte, 512) + n, err := f.Read(sniff) + if err != nil && err != io.EOF { + return err + } + + ct = http.DetectContentType(sniff[:n]) + r = io.MultiReader(bytes.NewReader(sniff[:n]), f) + } + + headers := textproto.MIMEHeader{} + headers.Set("Content-Disposition", multipart.FileContentDisposition(key, filename)) + headers.Set("Content-Type", ct) + + w, err := mpw.CreatePart(headers) + if err != nil { + return err + } + + _, err = io.Copy(w, r) + return err +} + +func detectTypeByExtension(filename string) string { + ext := strings.ToLower(filepath.Ext(filename)) + if ext == "" { + return "" + } + + switch ext { + // Images + case ".jpg", ".jpeg": + return "image/jpeg" + case ".png": + return "image/png" + case ".gif": + return "image/gif" + case ".webp": + return "image/webp" + case ".avif": + return "image/avif" + case ".heic", ".heif": + return "image/heif" + case ".jxl": + return "image/jxl" + case ".tif", ".tiff": + return "image/tiff" + case ".bmp": + return "image/bmp" + case ".ico": + return "image/x-icon" + case ".svg": + return "image/svg+xml" + case ".psd": + return "image/vnd.adobe.photoshop" + case ".raw", ".dng", ".nef", ".cr2", ".arw": + return "image/x-raw" + + // Video + case ".mp4": + return "video/mp4" + case ".m4v": + return "video/x-m4v" + case ".webm": + return "video/webm" + case ".mov": + return "video/quicktime" + case ".mkv": + return "video/x-matroska" + case ".avi": + return "video/x-msvideo" + case ".wmv": + return "video/x-ms-wmv" + case ".flv": + return "video/x-flv" + case ".mpeg", ".mpg": + return "video/mpeg" + case ".ogv": + return "video/ogg" + + // Audio + case ".mp3": + return "audio/mpeg" + case ".m4a": + return "audio/mp4" + case ".aac": + return "audio/aac" + case ".wav": + return "audio/wav" + case ".flac": + return "audio/flac" + case ".ogg": + return "audio/ogg" + case ".opus": + return "audio/opus" + case ".aiff", ".aif": + return "audio/aiff" + case ".mid", ".midi": + return "audio/midi" + + // Documents + case ".pdf": + return "application/pdf" + case ".txt": + return "text/plain; charset=utf-8" + case ".html", ".htm": + return "text/html; charset=utf-8" + case ".css": + return "text/css; charset=utf-8" + case ".csv": + return "text/csv; charset=utf-8" + case ".json": + return "application/json" + case ".xml": + return "application/xml" + case ".yaml", ".yml": + return "application/yaml" + case ".md": + return "text/markdown; charset=utf-8" + case ".rtf": + return "application/rtf" + + // Office formats + case ".doc": + return "application/msword" + case ".docx": + return "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + case ".xls": + return "application/vnd.ms-excel" + case ".xlsx": + return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + case ".ppt": + return "application/vnd.ms-powerpoint" + case ".pptx": + return "application/vnd.openxmlformats-officedocument.presentationml.presentation" + + // Fonts + case ".woff": + return "font/woff" + case ".woff2": + return "font/woff2" + case ".ttf": + return "font/ttf" + case ".otf": + return "font/otf" + case ".eot": + return "application/vnd.ms-fontobject" + + // Archives + case ".zip": + return "application/zip" + case ".tar": + return "application/x-tar" + case ".gz": + return "application/gzip" + case ".tgz": + return "application/gzip" + case ".bz2": + return "application/x-bzip2" + case ".xz": + return "application/x-xz" + case ".7z": + return "application/x-7z-compressed" + case ".rar": + return "application/vnd.rar" + + // Executables / binaries + case ".exe": + return "application/vnd.microsoft.portable-executable" + case ".msi": + return "application/x-msi" + case ".deb": + return "application/vnd.debian.binary-package" + case ".rpm": + return "application/x-rpm" + + // Scripts / code + case ".js": + return "application/javascript" + case ".mjs": + return "application/javascript" + case ".ts": + return "application/typescript" + case ".go": + return "text/x-go; charset=utf-8" + case ".rs": + return "text/x-rust; charset=utf-8" + case ".py": + return "text/x-python; charset=utf-8" + case ".sh": + return "application/x-sh" + + default: + return "" + } +} diff --git a/internal/multipart/multipart_test.go b/internal/multipart/multipart_test.go new file mode 100644 index 0000000..fdebec5 --- /dev/null +++ b/internal/multipart/multipart_test.go @@ -0,0 +1,129 @@ +package multipart + +import ( + "bytes" + "mime" + "mime/multipart" + "os" + "testing" + + "github.com/ryanfowler/fetch/internal/core" +) + +func TestMultipart(t *testing.T) { + tests := []struct { + name string + fnPre func(*testing.T) ([]core.KeyVal, func()) + fnPost func(*testing.T, *multipart.Form) + }{ + { + name: "small json file", + fnPre: func(t *testing.T) ([]core.KeyVal, func()) { + t.Helper() + + f, err := os.CreateTemp("", "*.json") + if err != nil { + t.Fatalf("unable to create temp file: %s", err.Error()) + } + defer f.Close() + f.WriteString(`{"key":"val"}`) + + return []core.KeyVal{{Key: "key1", Val: "@" + f.Name()}}, func() { + os.Remove(f.Name()) + } + }, + fnPost: func(t *testing.T, f *multipart.Form) { + t.Helper() + + header := f.File["key1"][0] + if ct := header.Header.Get("Content-Type"); ct != "application/json" { + t.Fatalf("unexpected content-type: %q", ct) + } + + file, err := header.Open() + if err != nil { + t.Fatalf("unable to open file: %s", err.Error()) + } + defer file.Close() + + var buf bytes.Buffer + buf.ReadFrom(file) + if buf.String() != `{"key":"val"}` { + t.Fatalf("unexpected file content: %q", buf.String()) + } + }, + }, + { + name: "file longer than 512 bytes with no extension", + fnPre: func(t *testing.T) ([]core.KeyVal, func()) { + t.Helper() + + f, err := os.CreateTemp("", "") + if err != nil { + t.Fatalf("unable to create temp file: %s", err.Error()) + } + defer f.Close() + + f.WriteString("\xFF\xD8\xFF") // JPEG signature. + f.Write(make([]byte, 512)) + + return []core.KeyVal{{Key: "key1", Val: "@" + f.Name()}}, func() { + os.Remove(f.Name()) + } + }, + fnPost: func(t *testing.T, f *multipart.Form) { + t.Helper() + + header := f.File["key1"][0] + if ct := header.Header.Get("Content-Type"); ct != "image/jpeg" { + t.Fatalf("unexpected content-type: %q", ct) + } + + file, err := header.Open() + if err != nil { + t.Fatalf("unable to open file: %s", err.Error()) + } + defer file.Close() + + var buf bytes.Buffer + buf.ReadFrom(file) + + var exp bytes.Buffer + exp.WriteString("\xFF\xD8\xFF") + exp.Write(make([]byte, 512)) + if !bytes.Equal(buf.Bytes(), exp.Bytes()) { + t.Fatalf("unexpected file content: %q", buf.String()) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + input, fn := test.fnPre(t) + if fn != nil { + defer fn() + } + + mp := NewMultipart(input) + + var buf bytes.Buffer + _, err := buf.ReadFrom(mp) + if err != nil { + t.Fatalf("unable to read from multipart: %s", err.Error()) + } + + _, params, err := mime.ParseMediaType(mp.ContentType()) + if err != nil { + t.Fatalf("unable to parse media type: %s", err.Error()) + } + + form, err := multipart.NewReader(&buf, params["boundary"]).ReadForm(1 << 24) + if err != nil { + t.Fatalf("unable to read multipart form: %s", err.Error()) + } + + test.fnPost(t, form) + }) + } +}