diff --git a/internal/app/app.go b/internal/app/app.go index 194407b..ec1f933 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -3,6 +3,7 @@ package app import ( "fmt" "log" + "net" "net/http" "time" @@ -31,10 +32,13 @@ func Run(cfg Config) error { } go systemCheck(db, noodleChannel) + go expirationCheck(db, noodleChannel) go tcpProxify(noodleChannel) http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(web.StaticFiles())))) http.HandleFunc("/", web.HandleMain(db, &noodleChannel)) + http.HandleFunc("/add", web.HandleAdd(db, &noodleChannel)) + http.HandleFunc("/toggle", web.HandleToggle(db, &noodleChannel)) http.HandleFunc("/delete", web.HandleDelete(db, &noodleChannel)) log.Printf("Server starting on %s", listenAddr) @@ -64,8 +68,11 @@ func tcpProxify(noodleChannel chan noodle.Noodle) { var proxy tcpproxy.Proxy src := fmt.Sprintf("0.0.0.0:%d", item.ListenPort) dst := fmt.Sprintf("%s:%d", item.DestHost, item.DestPort) - log.Printf("Starting a noodle from %s to %s", src, dst) - proxy.AddRoute(src, tcpproxy.To(dst)) + log.Printf("Starting a noodle from %s to %s with source=%s", src, dst, item.Src) + proxy.AddRoute(src, sourceRestrictedTarget{ + allowedIP: item.Src, + target: tcpproxy.To(dst), + }) noodleMap[item.Id] = &proxy go startProxy(&proxy) continue @@ -76,6 +83,78 @@ func tcpProxify(noodleChannel chan noodle.Noodle) { if err := noodleMap[item.Id].Close(); err != nil { log.Print(err) } + delete(noodleMap, item.Id) + } + } +} + +type sourceRestrictedTarget struct { + allowedIP string + target tcpproxy.Target +} + +func (t sourceRestrictedTarget) HandleConn(conn net.Conn) { + if t.allowedIP == "" || t.allowedIP == "All" { + t.target.HandleConn(conn) + return + } + + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + log.Printf("Rejected noodle connection with invalid remote address %q", conn.RemoteAddr().String()) + conn.Close() + return + } + if host != t.allowedIP { + log.Printf("Rejected noodle connection from %s; allowed source is %s", host, t.allowedIP) + conn.Close() + return + } + + t.target.HandleConn(conn) +} + +func expirationCheck(db *noodle.Database, noodleChannel chan noodle.Noodle) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for range ticker.C { + noodles := db.GetAll() + for _, item := range noodles { + if !item.IsUp { + continue + } + if item.Expiration <= 0 { + if item.IsUp { + item.IsUp = false + noodleChannel <- item + } + if err := db.Delete(item.Id); err != nil { + log.Print(err) + } + continue + } + + item.Expiration -= time.Second + if item.Expiration <= 0 { + item.Expiration = 0 + if err := db.Update(item); err != nil { + log.Print(err) + continue + } + if item.IsUp { + item.IsUp = false + noodleChannel <- item + } + if err := db.Delete(item.Id); err != nil { + log.Print(err) + } + continue + } + + if err := db.Update(item); err != nil { + log.Print(err) + } } } } @@ -86,10 +165,11 @@ func runTestSequence(db *noodle.Database) { Id: db.MakeID(), Name: "Name_Test", Proto: "Proto_Test", + Src: "All", ListenPort: 1080 + i, DestPort: 22, DestHost: "localhost", - Expiration: time.Now().Second(), + Expiration: time.Duration(time.Now().Second()) * time.Second, IsUp: true, } log.Printf("Test noodle=%v", item) diff --git a/internal/assets/templates/index.html b/internal/assets/templates/index.html index c404139..9100fab 100644 --- a/internal/assets/templates/index.html +++ b/internal/assets/templates/index.html @@ -43,6 +43,7 @@