// Package web implements the HTTP handlers for the noodle management UI.
package web

import (
	"bytes"
	"html/template"
	"io/fs"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"infinite-noodle/internal/assets"
	"infinite-noodle/internal/noodle"
)

// StaticFiles returns the embedded asset filesystem (assets.FS) so callers
// can serve static files from it.
func StaticFiles() fs.FS {
	return assets.FS
}

// indexPageData is the view model passed to the "index.html" template.
type indexPageData struct {
	ClientIP string
	Noodles  []noodle.Noodle
}

// redirectHome sends the browser back to the index page. 303 See Other is
// used so that a form POST is followed by a plain GET of "/" (the
// Post/Redirect/Get pattern); 307 would make the browser replay the POST.
func redirectHome(w http.ResponseWriter, req *http.Request) {
	http.Redirect(w, req, "/", http.StatusSeeOther)
}

// parsePort parses a port number from form input, accepting only values in
// the valid TCP/UDP range 1-65535. The second result is false when the input
// is not an integer or is out of range.
func parsePort(raw string) (int, bool) {
	port, err := strconv.Atoi(strings.TrimSpace(raw))
	if err != nil || port < 1 || port > 65535 {
		return 0, false
	}
	return port, true
}

// HandleMain returns a handler that renders "index.html" with the caller's
// IP address and every noodle currently stored in db.
//
// Templates are parsed once, when HandleMain is called; a parse failure is
// fatal (log.Fatalf). pc is accepted for signature symmetry with the other
// handlers and is not used.
func HandleMain(db *noodle.Database, pc *chan noodle.Noodle) func(w http.ResponseWriter, req *http.Request) {
	tmpl, err := template.ParseFS(assets.FS, "templates/*.html")
	if err != nil {
		log.Fatalf("Error parsing templates: %v", err)
	}

	return func(w http.ResponseWriter, req *http.Request) {
		data := indexPageData{
			ClientIP: clientIPFromRequest(req),
			Noodles:  db.GetAll(),
		}

		// Render into a buffer first so a template error yields a clean 500
		// instead of a half-written page with error text appended to it.
		var buf bytes.Buffer
		if err := tmpl.ExecuteTemplate(&buf, "index.html", data); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		// A write failure here means the client went away; nothing to report.
		_, _ = buf.WriteTo(w)
	}
}

// HandleDelete returns a handler that deletes the noodle named by the "id"
// query parameter and then redirects to "/".
//
// Before deleting, the noodle is marked down (IsUp = false) and sent on *pc
// so the consumer of that channel can react to the state change.
// NOTE(review): the consumer lives outside this file; presumably it tears down
// the forwarder — confirm. Missing or unknown ids are ignored (just
// redirected) rather than pushing a zero-value noodle onto the channel.
func HandleDelete(db *noodle.Database, pc *chan noodle.Noodle) func(w http.ResponseWriter, req *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		id := req.URL.Query().Get("id")
		if id != "" {
			item := db.Get(id)
			if item.Id == "" {
				// db.Get returns a zero-value noodle when the id is unknown
				// (e.g. a stale link or double click); there is nothing to stop.
				log.Printf("Delete requested for unknown noodle id=%q", id)
			} else {
				item.IsUp = false
				*pc <- item
				log.Printf("Deleting noodle=%v", item)
				db.Delete(item.Id)
			}
		}
		redirectHome(w, req)
	}
}

// HandleAdd returns a handler that creates a noodle from a POSTed form.
//
// Form fields:
//   - name:        free-form label
//   - proto:       "TCP" or "UDP" (case-insensitive)
//   - src:         "All" or a single IP address allowed to connect
//   - listen_port: port to listen on (1-65535)
//   - dest_host:   destination host
//   - dest_port:   destination port (1-65535)
//   - expiration:  Go duration string (time.ParseDuration), must not be negative
//
// On success the noodle is stored with IsUp=true, sent on *pc, and the client
// is redirected to "/". Invalid input yields 400; a storage failure yields 500.
func HandleAdd(db *noodle.Database, pc *chan noodle.Noodle) func(w http.ResponseWriter, req *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		if req.Method != http.MethodPost {
			http.Error(w, "", http.StatusMethodNotAllowed)
			return
		}
		if err := req.ParseForm(); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		listenPort, ok := parsePort(req.FormValue("listen_port"))
		if !ok {
			http.Error(w, "invalid listen port", http.StatusBadRequest)
			return
		}
		destPort, ok := parsePort(req.FormValue("dest_port"))
		if !ok {
			http.Error(w, "invalid destination port", http.StatusBadRequest)
			return
		}

		expiration, err := time.ParseDuration(strings.TrimSpace(req.FormValue("expiration")))
		if err != nil || expiration < 0 {
			http.Error(w, "invalid expiration duration", http.StatusBadRequest)
			return
		}

		proto := strings.ToUpper(strings.TrimSpace(req.FormValue("proto")))
		if proto != "TCP" && proto != "UDP" {
			http.Error(w, "invalid protocol", http.StatusBadRequest)
			return
		}

		src := strings.TrimSpace(req.FormValue("src"))
		if src != "All" {
			ip := net.ParseIP(src)
			if ip == nil {
				http.Error(w, "invalid source restriction", http.StatusBadRequest)
				return
			}
			// Store the canonical textual form, the same form that
			// clientIPFromRequest produces (net.IP.String).
			src = ip.String()
		}

		item := noodle.Noodle{
			Id:         db.MakeID(),
			Name:       strings.TrimSpace(req.FormValue("name")),
			Proto:      proto,
			Src:        src,
			ListenPort: listenPort,
			DestPort:   destPort,
			DestHost:   strings.TrimSpace(req.FormValue("dest_host")),
			Expiration: expiration,
			IsUp:       true,
		}
		if err := db.Add(item); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		*pc <- item
		redirectHome(w, req)
	}
}

// HandleToggle returns a handler that sets a noodle's IsUp flag from a POSTed
// form: "id" selects the noodle and "is_up" == "on" (an HTML checkbox) marks
// it up, anything else marks it down. The updated noodle is persisted, sent
// on *pc, and the client is redirected to "/".
//
// Responses: 405 for non-POST, 400 for a missing id, 404 for an unknown id,
// 500 if the update fails.
func HandleToggle(db *noodle.Database, pc *chan noodle.Noodle) func(w http.ResponseWriter, req *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		if req.Method != http.MethodPost {
			http.Error(w, "", http.StatusMethodNotAllowed)
			return
		}
		if err := req.ParseForm(); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		id := strings.TrimSpace(req.FormValue("id"))
		if id == "" {
			http.Error(w, "missing noodle id", http.StatusBadRequest)
			return
		}

		item := db.Get(id)
		if item.Id == "" {
			http.Error(w, "noodle not found", http.StatusNotFound)
			return
		}

		item.IsUp = req.FormValue("is_up") == "on"
		if err := db.Update(item); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		*pc <- item
		redirectHome(w, req)
	}
}

// clientIPFromRequest makes a best-effort guess at the client's IP address.
// It tries, in order:
//   - the first entry of X-Forwarded-For
//   - X-Real-Ip
//   - the host part of req.RemoteAddr
//   - req.RemoteAddr itself
//
// It returns the first value that parses as an IP, in net.IP.String form, and
// "127.0.0.1" if none does.
//
// SECURITY: X-Forwarded-For and X-Real-Ip are supplied by the client and are
// trivially spoofable unless a trusted reverse proxy overwrites them. Do not
// use this result for access-control decisions.
func clientIPFromRequest(req *http.Request) string {
	if forwardedFor := strings.TrimSpace(req.Header.Get("X-Forwarded-For")); forwardedFor != "" {
		first, _, _ := strings.Cut(forwardedFor, ",")
		if parsed := net.ParseIP(strings.TrimSpace(first)); parsed != nil {
			return parsed.String()
		}
	}

	if realIP := strings.TrimSpace(req.Header.Get("X-Real-Ip")); realIP != "" {
		if parsed := net.ParseIP(realIP); parsed != nil {
			return parsed.String()
		}
	}

	remote := strings.TrimSpace(req.RemoteAddr)
	if host, _, err := net.SplitHostPort(remote); err == nil {
		if parsed := net.ParseIP(host); parsed != nil {
			return parsed.String()
		}
	}
	if parsed := net.ParseIP(remote); parsed != nil {
		return parsed.String()
	}
	return "127.0.0.1"
}