From 230e58c2865cf97c3cf5dbbe2c6caeb9efaf2958 Mon Sep 17 00:00:00 2001 From: vxclutch Date: Fri, 29 May 2026 08:05:00 -0400 Subject: [PATCH] feat: kill server after n downloads --- TODO | 1 - cmd/lash/main.go | 13 +++++++++-- internal/app/routes.go | 11 ++++++---- internal/handlers/file.go | 39 ++++++++++++++++++++++++--------- internal/shareLink/shareLink.go | 14 ++++++------ lash.go | 6 +++++ version | 2 +- 7 files changed, 61 insertions(+), 25 deletions(-) diff --git a/TODO b/TODO index 1872c5f..dfa64f5 100644 --- a/TODO +++ b/TODO @@ -5,5 +5,4 @@ feat: improve flags feat: replace uuid dep with custom id generator feat: multiple files fix: remove `_` prefix -fix: kill server after `n` downloads feat: use zeroconf for mDNS diff --git a/cmd/lash/main.go b/cmd/lash/main.go index 4d5571c..6044832 100644 --- a/cmd/lash/main.go +++ b/cmd/lash/main.go @@ -14,6 +14,7 @@ import ( var versionFlag = flag.Bool("version", false, "Print out version and exit.") var port = flag.Int("p", 1337, "Set the port for LASH exchanges.") +var numberOfShares = flag.Int("n", -1, "Number of shares to allow before killing the server.") func main() { flag.Parse() @@ -23,14 +24,22 @@ func main() { return } - srv := app.New() + ctx := lash.LashContext{ + N: *numberOfShares, + } + + srv := app.New(&ctx) // TODO(vxc): Make this more portable errx.Log("Your share link is %s", share.GenerateShareLink(*port)) errx.Log("Your token is \033[1;92m%s\033[0m", lash.Token) errx.Log("starting server at http://0.0.0.0:%d", *port) - if err := http.ListenAndServe(fmt.Sprintf(":%d", *port), srv); err != nil { + server := http.Server{ + Addr: fmt.Sprintf(":%d", *port), + Handler: srv, + } + if err := server.ListenAndServe(); err != nil { errx.FatalPerror(err) } } diff --git a/internal/app/routes.go b/internal/app/routes.go index ca75ca7..780d4d9 100644 --- a/internal/app/routes.go +++ b/internal/app/routes.go @@ -8,7 +8,7 @@ import ( "os" ) -func New() http.Handler { +func New(ctx *lash.LashContext) http.Handler { mux := http.NewServeMux() fp, err := GetFilePath() @@ -25,9 +25,12 @@ func New() http.Handler { Version: lash.Version, } - file := handlers.FileData{ - Contents: contents, - FileName: fp, + file := handlers.FileHandler{ + Ctx: ctx, + FileData: handlers.FileData{ + Contents: contents, + FileName: fp, + }, } mux.HandleFunc("/", share.Handler) diff --git a/internal/handlers/file.go b/internal/handlers/file.go index 53d802a..974152b 100644 --- a/internal/handlers/file.go +++ b/internal/handlers/file.go @@ -5,9 +5,15 @@ import ( "fmt" "lash" "net/http" + "os" "strconv" ) +type FileHandler struct { + Ctx *lash.LashContext + FileData FileData +} + type FileData struct { Contents []byte FileName string @@ -17,7 +23,14 @@ type ValidateRequest struct { Token string } -func (h FileData) APIHandler(w http.ResponseWriter, r *http.Request) { +var sent int = 0 + +func (h FileHandler) APIHandler(w http.ResponseWriter, r *http.Request) { + if sent >= h.Ctx.N && h.Ctx.N != -1 { + w.WriteHeader(http.StatusTooManyRequests) + os.Exit(0) + return + } decoder := json.NewDecoder(r.Body) var t ValidateRequest @@ -32,21 +45,27 @@ func (h FileData) APIHandler(w http.ResponseWriter, r *http.Request) { return } - // Send the file over as a stream of bytes - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", h.FileName)) + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", h.FileData.FileName)) w.Header().Set("Content-Type", "application/octet-stream") - w.Header().Set("Content-Length", strconv.Itoa(len(h.Contents))) + w.Header().Set("Content-Length", strconv.Itoa(len(h.FileData.Contents))) w.WriteHeader(http.StatusOK) - w.Write(h.Contents) + w.Write(h.FileData.Contents) + sent++ } -func (h FileData) FileHandler(w http.ResponseWriter, r *http.Request) { - // Send the file over as a stream of bytes - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", h.FileName)) +func (h FileHandler) FileHandler(w http.ResponseWriter, r *http.Request) { + if sent >= h.Ctx.N && h.Ctx.N != -1 { + w.WriteHeader(http.StatusTooManyRequests) + os.Exit(0) + return + } + + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", h.FileData.FileName)) w.Header().Set("Content-Type", "application/octet-stream") - w.Header().Set("Content-Length", strconv.Itoa(len(h.Contents))) + w.Header().Set("Content-Length", strconv.Itoa(len(h.FileData.Contents))) w.WriteHeader(http.StatusOK) - w.Write(h.Contents) + w.Write(h.FileData.Contents) + sent++ } diff --git a/internal/shareLink/shareLink.go b/internal/shareLink/shareLink.go index 2e04af6..0756014 100644 --- a/internal/shareLink/shareLink.go +++ b/internal/shareLink/shareLink.go @@ -17,13 +17,13 @@ func GenerateShareLink(port int) string { } func getLocalIP() string { - conn, err := net.Dial("udp", "8.8.8.8:80") - if err != nil { - errx.FatalPerror(err) - } - defer conn.Close() + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + errx.FatalPerror(err) + } + defer conn.Close() - localAddress := conn.LocalAddr().(*net.UDPAddr) + localAddress := conn.LocalAddr().(*net.UDPAddr) - return localAddress.IP.String() + return localAddress.IP.String() } diff --git a/lash.go b/lash.go index 70741a8..12a594f 100644 --- a/lash.go +++ b/lash.go @@ -2,10 +2,16 @@ package lash import ( "embed" + "net/http" "github.com/google/uuid" ) +type LashContext struct { + N int + Server *http.Server +} + //go:embed templates/*.html var Templates embed.FS diff --git a/version b/version index 3eefcb9..00750ed 100644 --- a/version +++ b/version @@ -1 +1 @@ -1.0.0 +3