diff --git a/main.go b/main.go index dee78fd..e065e3a 100644 --- a/main.go +++ b/main.go @@ -99,12 +99,12 @@ func Zeros(path string, size int64) error { return nil } -func NameGen() string { - const chars = "abcdefghjkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ123456789" +func NameGen(fileNameLength int) string { + const chars = "abcdefghjkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ1234567890-_" ll := len(chars) - b := make([]byte, conf.FileLen) + b := make([]byte, fileNameLength) rand.Read(b) // generates len(b) random bytes - for i := int64(0); i < int64(conf.FileLen); i++ { + for i := int64(0); i < int64(fileNameLength); i++ { b[i] = chars[int(b[i])%ll] } return string(b) @@ -120,8 +120,12 @@ func CheckFile(name string) bool { // false if doesn't exist, true if exists } func UploadHandler(w http.ResponseWriter, r *http.Request) { - // expiry sanitize - twentyfour := int64(conf.FileExpirySeconds) + var name string + var expiryTime int64 + var fileNameLength int + + fileNameLength = 0 + expiryTime = 0 file, _, err := r.FormFile("file") if err != nil { @@ -137,10 +141,47 @@ func UploadHandler(w http.ResponseWriter, r *http.Request) { } file.Seek(0, 0) + // Check if expiry time is present and length is too long + if r.PostFormValue("expiry") != "" { + expiryTime, err = strconv.ParseInt(r.PostFormValue("expiry"), 10, 64) + if err != nil { + log.Error().Err(err).Msg("expiry could not be parsed") + } else { + // 5 days max + if expiryTime < 1 || expiryTime > 432000 { + w.WriteHeader(http.StatusBadRequest) + return + } + } + } + + // Default to conf if not present + if expiryTime == 0 { + expiryTime = int64(conf.FileExpirySeconds) + } + + // Check if the file length parameter exists and also if it's too long + if r.PostFormValue("url_len") != "" { + fileNameLength, err = strconv.Atoi(r.PostFormValue("url_len")) + if err != nil { + log.Error().Err(err).Msg("url_len could not be parsed") + } else { + // if the length is < 3 and > 128 return error + if fileNameLength < 3 || fileNameLength > 128 { + w.WriteHeader(http.StatusBadRequest) + return + } + } + } + + // Default to conf if not present + if fileNameLength == 0 { + fileNameLength = conf.FileLen + } + // generate + check name - var name string for { - id := NameGen() + id := NameGen(fileNameLength) name = id + mtype.Extension() if !CheckFile(name) { break @@ -149,14 +190,14 @@ func UploadHandler(w http.ResponseWriter, r *http.Request) { err = db.Update(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("expiry")) - err := b.Put([]byte(name), []byte(strconv.FormatInt(time.Now().Unix()+twentyfour, 10))) + err := b.Put([]byte(name), []byte(strconv.FormatInt(time.Now().Unix()+expiryTime, 10))) return err }) if err != nil { log.Error().Err(err).Msg("Failed to put expiry") } - log.Info().Int64("expiry", twentyfour).Msg("Writing new file") + log.Info().Int64("expiry", expiryTime).Msg("Writing new file") f, err := os.OpenFile(conf.FileFolder+"/"+name, os.O_WRONLY|os.O_CREATE, 0644) if err != nil {