diff --git a/internal/app/app.go b/internal/app/app.go index 194407b..ec1f933 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -3,6 +3,7 @@ package app import ( "fmt" "log" + "net" "net/http" "time" @@ -31,10 +32,13 @@ func Run(cfg Config) error { } go systemCheck(db, noodleChannel) + go expirationCheck(db, noodleChannel) go tcpProxify(noodleChannel) http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(web.StaticFiles())))) http.HandleFunc("/", web.HandleMain(db, &noodleChannel)) + http.HandleFunc("/add", web.HandleAdd(db, &noodleChannel)) + http.HandleFunc("/toggle", web.HandleToggle(db, &noodleChannel)) http.HandleFunc("/delete", web.HandleDelete(db, &noodleChannel)) log.Printf("Server starting on %s", listenAddr) @@ -64,8 +68,11 @@ func tcpProxify(noodleChannel chan noodle.Noodle) { var proxy tcpproxy.Proxy src := fmt.Sprintf("0.0.0.0:%d", item.ListenPort) dst := fmt.Sprintf("%s:%d", item.DestHost, item.DestPort) - log.Printf("Starting a noodle from %s to %s", src, dst) - proxy.AddRoute(src, tcpproxy.To(dst)) + log.Printf("Starting a noodle from %s to %s with source=%s", src, dst, item.Src) + proxy.AddRoute(src, sourceRestrictedTarget{ + allowedIP: item.Src, + target: tcpproxy.To(dst), + }) noodleMap[item.Id] = &proxy go startProxy(&proxy) continue @@ -76,6 +83,78 @@ func tcpProxify(noodleChannel chan noodle.Noodle) { if err := noodleMap[item.Id].Close(); err != nil { log.Print(err) } + delete(noodleMap, item.Id) + } + } +} + +type sourceRestrictedTarget struct { + allowedIP string + target tcpproxy.Target +} + +func (t sourceRestrictedTarget) HandleConn(conn net.Conn) { + if t.allowedIP == "" || t.allowedIP == "All" { + t.target.HandleConn(conn) + return + } + + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + log.Printf("Rejected noodle connection with invalid remote address %q", conn.RemoteAddr().String()) + conn.Close() + return + } + if host != t.allowedIP { + log.Printf("Rejected noodle connection from %s; allowed source is %s", host, t.allowedIP) + conn.Close() + return + } + + t.target.HandleConn(conn) +} + +func expirationCheck(db *noodle.Database, noodleChannel chan noodle.Noodle) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for range ticker.C { + noodles := db.GetAll() + for _, item := range noodles { + if !item.IsUp { + continue + } + if item.Expiration <= 0 { + if item.IsUp { + item.IsUp = false + noodleChannel <- item + } + if err := db.Delete(item.Id); err != nil { + log.Print(err) + } + continue + } + + item.Expiration -= time.Second + if item.Expiration <= 0 { + item.Expiration = 0 + if err := db.Update(item); err != nil { + log.Print(err) + continue + } + if item.IsUp { + item.IsUp = false + noodleChannel <- item + } + if err := db.Delete(item.Id); err != nil { + log.Print(err) + } + continue + } + + if err := db.Update(item); err != nil { + log.Print(err) + } } } } @@ -86,10 +165,11 @@ func runTestSequence(db *noodle.Database) { Id: db.MakeID(), Name: "Name_Test", Proto: "Proto_Test", + Src: "All", ListenPort: 1080 + i, DestPort: 22, DestHost: "localhost", - Expiration: time.Now().Second(), + Expiration: time.Duration(time.Now().Second()) * time.Second, IsUp: true, } log.Printf("Test noodle=%v", item) diff --git a/internal/assets/templates/index.html b/internal/assets/templates/index.html index c404139..9100fab 100644 --- a/internal/assets/templates/index.html +++ b/internal/assets/templates/index.html @@ -43,6 +43,7 @@ Name Proto + Allow From Listening Port Dest Port Dest Host/IP @@ -53,34 +54,45 @@ - + TCP - - - - Expiration + + + + + + + + + + + - +
+ +
- {{range .}} + {{range .Noodles}} {{.Name}} {{.Proto}} + {{.Src}} {{.ListenPort}} {{.DestPort}} {{.DestHost}} - {{.Expiration}} + {{.Expiration}} - {{if .IsUp}} - - {{ else }} - - {{ end }} +
+ + +
@@ -99,6 +111,59 @@

+ - \ No newline at end of file + diff --git a/internal/noodle/database.go b/internal/noodle/database.go index 5821301..1624905 100644 --- a/internal/noodle/database.go +++ b/internal/noodle/database.go @@ -69,6 +69,14 @@ func (db *Database) Add(item Noodle) error { return nil } +func (db *Database) Update(item Noodle) error { + if err := db.Handle.Add(item.Id, item); err != nil { + log.Print(err) + return err + } + return nil +} + func (db *Database) Delete(id string) error { if err := db.Handle.Delete(id); err != nil { log.Print(err) diff --git a/internal/noodle/model.go b/internal/noodle/model.go index 4a53b49..6985f45 100644 --- a/internal/noodle/model.go +++ b/internal/noodle/model.go @@ -1,12 +1,15 @@ package noodle +import "time" + type Noodle struct { Id string Name string Proto string + Src string ListenPort int DestPort int DestHost string - Expiration int + Expiration time.Duration IsUp bool } diff --git a/internal/web/handlers.go b/internal/web/handlers.go index 7501199..1eeae98 100644 --- a/internal/web/handlers.go +++ b/internal/web/handlers.go @@ -4,7 +4,11 @@ import ( "html/template" "io/fs" "log" + "net" "net/http" + "strconv" + "strings" + "time" "infinite-noodle/internal/assets" "infinite-noodle/internal/noodle" @@ -14,13 +18,21 @@ func StaticFiles() fs.FS { return assets.FS } +type indexPageData struct { + ClientIP string + Noodles []noodle.Noodle +} + func HandleMain(db *noodle.Database, pc *chan noodle.Noodle) func(w http.ResponseWriter, req *http.Request) { tmpl, err := template.ParseFS(assets.FS, "templates/*.html") if err != nil { log.Fatalf("Error parsing templates: %v", err) } return func(w http.ResponseWriter, req *http.Request) { - data := db.GetAll() + data := indexPageData{ + ClientIP: clientIPFromRequest(req), + Noodles: db.GetAll(), + } if err := tmpl.ExecuteTemplate(w, "index.html", data); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -51,6 +63,122 @@ func HandleAdd(db *noodle.Database, pc *chan noodle.Noodle) func(w http.Response return } + if err := req.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + listenPort, err := strconv.Atoi(strings.TrimSpace(req.FormValue("listen_port"))) + if err != nil { + http.Error(w, "invalid listen port", http.StatusBadRequest) + return + } + + destPort, err := strconv.Atoi(strings.TrimSpace(req.FormValue("dest_port"))) + if err != nil { + http.Error(w, "invalid destination port", http.StatusBadRequest) + return + } + + expiration, err := time.ParseDuration(strings.TrimSpace(req.FormValue("expiration"))) + if err != nil { + http.Error(w, "invalid expiration duration", http.StatusBadRequest) + return + } + + clientIP := clientIPFromRequest(req) + src := strings.TrimSpace(req.FormValue("src")) + if src == clientIP { + src = clientIP + } + if src != "All" && net.ParseIP(src) == nil { + http.Error(w, "invalid source restriction", http.StatusBadRequest) + return + } + + item := noodle.Noodle{ + Id: db.MakeID(), + Name: strings.TrimSpace(req.FormValue("name")), + Proto: "TCP", + Src: src, + ListenPort: listenPort, + DestPort: destPort, + DestHost: strings.TrimSpace(req.FormValue("dest_host")), + Expiration: expiration, + IsUp: true, + } + if err := db.Add(item); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + *pc <- item + http.Redirect(w, req, "/", http.StatusTemporaryRedirect) } } + +func HandleToggle(db *noodle.Database, pc *chan noodle.Noodle) func(w http.ResponseWriter, req *http.Request) { + return func(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(w, "", http.StatusMethodNotAllowed) + return + } + + if err := req.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + id := strings.TrimSpace(req.FormValue("id")) + if id == "" { + http.Error(w, "missing noodle id", http.StatusBadRequest) + return + } + + item := db.Get(id) + if item.Id == "" { + http.Error(w, "noodle not found", http.StatusNotFound) + return + } + + item.IsUp = req.FormValue("is_up") == "on" + if err := db.Update(item); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + *pc <- item + + http.Redirect(w, req, "/", http.StatusTemporaryRedirect) + } +} + +func clientIPFromRequest(req *http.Request) string { + forwardedFor := strings.TrimSpace(req.Header.Get("X-Forwarded-For")) + if forwardedFor != "" { + parts := strings.Split(forwardedFor, ",") + ip := strings.TrimSpace(parts[0]) + if parsed := net.ParseIP(ip); parsed != nil { + return parsed.String() + } + } + + realIP := strings.TrimSpace(req.Header.Get("X-Real-Ip")) + if realIP != "" { + if parsed := net.ParseIP(realIP); parsed != nil { + return parsed.String() + } + } + + host, _, err := net.SplitHostPort(strings.TrimSpace(req.RemoteAddr)) + if err == nil { + if parsed := net.ParseIP(host); parsed != nil { + return parsed.String() + } + } + + if parsed := net.ParseIP(strings.TrimSpace(req.RemoteAddr)); parsed != nil { + return parsed.String() + } + + return "127.0.0.1" +}