diff --git a/auth.go b/auth.go index bd5d316..424975c 100644 --- a/auth.go +++ b/auth.go @@ -41,11 +41,7 @@ func resolveDisplay(display string) (x11probe.ConnType, string) { } func getXAuthCookie(display string) ([]byte, error) { - out, err := exec.Command("xauth", "list").Output() - if err != nil { - return nil, err - } - + // Extract display number var dispNum string if strings.HasPrefix(display, ":") { dispNum = strings.Split(display[1:], ".")[0] @@ -54,16 +50,51 @@ func getXAuthCookie(display string) ([]byte, error) { dispNum = strings.Split(parts[1], ".")[0] } + // Query only entries for that display + out, err := exec.Command("xauth", "list", ":"+dispNum).Output() + if err != nil { + return nil, fmt.Errorf("xauth list failed: %w", err) + } + lines := strings.Split(string(out), "\n") + var bestMatch string + var bestCookie string + var bestScore int + for _, line := range lines { fields := strings.Fields(line) - if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" { - entry := fields[0] - if strings.Contains(entry, ":"+dispNum) { - fmt.Printf("Using XAuth cookie from entry: %s\n", entry) - return parseHexCookie(fields[2]), nil - } + if len(fields) < 3 || fields[1] != "MIT-MAGIC-COOKIE-1" { + continue } + + entry := fields[0] + cookie := fields[2] + + // Score based on specificity + score := 0 + switch { + case strings.HasPrefix(entry, os.Getenv("HOSTNAME")+"/unix:"): + score = 100 + case strings.HasPrefix(entry, "unix:"): + score = 90 + case strings.Contains(entry, "/unix:"): + score = 80 + case strings.HasPrefix(entry, "#ffff#"): + score = 70 + default: + score = 50 + } + + if score > bestScore { + bestMatch = entry + bestCookie = cookie + bestScore = score + } + } + + if bestCookie != "" { + fmt.Printf("Using XAuth cookie from entry: %s\n", bestMatch) + return parseHexCookie(bestCookie), nil } return nil, fmt.Errorf("no matching cookie found for display %s", display) diff --git a/main.go b/main.go index 37d39b0..6abfa95 100644 --- a/main.go +++ b/main.go @@ -48,15 +48,6 @@ func main() { log.Fatalf("Connection probe error: %v", err) } - switch status { - case x11probe.X11Dead: - log.Fatalf("Target connection is dead or unreachable") - case x11probe.X11AuthRequired: - fmt.Println("Authentication required for the target connection.") - case x11probe.X11Live: - fmt.Println("Connection to target is live and valid.") - } - switch status { case x11probe.X11Live: log.Println("X11 socket is live and accepting connections") diff --git a/proxy.go b/proxy.go index 6a8ecab..653254e 100644 --- a/proxy.go +++ b/proxy.go @@ -8,6 +8,7 @@ import ( "log" "net" "os" + "time" "x11proxy/x11probe" ) @@ -30,11 +31,18 @@ func StartProxy(proxyPath, target string, connType x11probe.ConnType, display st go handleConnection(clientConn, connType, target, display) } } - func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { - // Step 1: Read fixed-length header + // If r is a net.Conn, set a read deadline + if conn, ok := r.(net.Conn); ok { + deadline := time.Now().Add(3 * time.Second) + conn.SetReadDeadline(deadline) + } + + // Step 1: Read fixed-length header manually header := make([]byte, 12) - if _, err := io.ReadFull(r, header); err != nil { + n, err := io.ReadFull(r, header) + if err != nil { + log.Printf("Handshake header read failed after %d bytes: %v", n, err) return nil, fmt.Errorf("failed to read handshake header: %w", err) } @@ -52,13 +60,19 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { authProtoLen := read16(header[6:8]) authDataLen := read16(header[8:10]) - // Step 3: Read remaining fields - totalLen := int(authProtoLen + authDataLen) - totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen)) + // Step 3 A: Read remaining fields + totalLen := padLen(int(authProtoLen)) + padLen(int(authDataLen)) extra := make([]byte, totalLen) - if _, err := io.ReadFull(r, extra); err != nil { + + n, err = io.ReadFull(r, extra) + if err != nil { + log.Printf("Handshake auth read failed after %d bytes: %v", n, err) return nil, fmt.Errorf("failed to read handshake auth fields: %w", err) } + // Step 3 B: clear timeout/deadline + if conn, ok := r.(net.Conn); ok { + conn.SetReadDeadline(time.Time{}) + } // Step 4: Decide whether to patch patch := authProtoLen == 0 || authDataLen == 0 @@ -76,7 +90,6 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { newAuthProtoLen := uint16(len(authProto)) newAuthDataLen := uint16(len(authData)) - // Patch header in-place if isLittle { binary.LittleEndian.PutUint16(header[6:8], newAuthProtoLen) binary.LittleEndian.PutUint16(header[8:10], newAuthDataLen) @@ -110,6 +123,7 @@ func binaryOrder(isLittle bool) binary.ByteOrder { func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) { var serverConn net.Conn var err error + var patch bool if connType == x11probe.Unix { serverConn, err = net.Dial("unix", target) @@ -123,33 +137,32 @@ func handleConnection(client net.Conn, connType x11probe.ConnType, target string } log.Printf("New connection from %v", client.RemoteAddr()) + patch = true cookie, err := getXAuthCookie(display) if err != nil || cookie == nil { - log.Printf("Failed to get XAuth cookie") - client.Close() - serverConn.Close() - return + patch = false + log.Printf("Failed to get XAuth cookie, try forwarding without patching") } - log.Printf("About to read the handshake") - patched, err := readAndPatchHandshake(client, cookie) - if err != nil { - log.Printf("Initial read failed: %v", err) - log.Fatal(err) - client.Close() - serverConn.Close() - return + if patch { + log.Printf("About to read the handshake") + patched, err := readAndPatchHandshake(client, cookie) + if err != nil { + log.Printf("Initial read failed: %v", err) + log.Fatal(err) + client.Close() + serverConn.Close() + return + } + _, err = serverConn.Write(patched) + if err != nil { + log.Printf("Initial write failed: %v", err) + client.Close() + serverConn.Close() + return + } } - - _, err = serverConn.Write(patched) - if err != nil { - log.Printf("Initial write failed: %v", err) - client.Close() - serverConn.Close() - return - } - done := make(chan struct{}, 2) go inspectAndForward(serverConn, client, "client→server", done)