From 1795e91387735d328fdbdfe1cb086a943305c7c0 Mon Sep 17 00:00:00 2001 From: Terrence Ezrol Date: Sat, 2 Aug 2025 22:23:57 -0400 Subject: [PATCH] add X11 socket/port probe --- auth.go | 175 ++++++++++++++++--------------- main.go | 144 +++++++++++++++----------- proxy.go | 257 ++++++++++++++++++++++------------------------ x11probe/probe.go | 125 ++++++++++++++++++++++ 4 files changed, 422 insertions(+), 279 deletions(-) create mode 100644 x11probe/probe.go diff --git a/auth.go b/auth.go index b103f72..bd5d316 100644 --- a/auth.go +++ b/auth.go @@ -1,116 +1,115 @@ package main import ( - "fmt" - "os" - "os/exec" - "path" - "strconv" - "strings" + "fmt" "net" + "os" + "os/exec" + "path" + "strconv" + "strings" + "x11proxy/x11probe" ) -func resolveDisplay(display string) (ConnType, string) { - var dispNum int - var err error +func resolveDisplay(display string) (x11probe.ConnType, string) { + var dispNum int + var err error - if strings.HasPrefix(display, ":") { - dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0]) - } else { - parts := strings.Split(display, ":") - dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0]) - } - if err != nil { - dispNum = 0 - } + if strings.HasPrefix(display, ":") { + dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0]) + } else { + parts := strings.Split(display, ":") + dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0]) + } + if err != nil { + dispNum = 0 + } - unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum)) - if _, err := os.Stat(unixPath); err == nil { - return Unix, unixPath - } + unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum)) + if _, err := os.Stat(unixPath); err == nil { + return x11probe.Unix, unixPath + } - tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum) - conn, err := net.Dial("tcp", tcpAddr) - if err == nil { - conn.Close() - return TCP, tcpAddr - } + tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum) + conn, err := net.Dial("tcp", tcpAddr) + if err == nil { + conn.Close() + return x11probe.TCP, tcpAddr + } - return Unix, unixPath + return x11probe.Unix, unixPath } func getXAuthCookie(display string) ([]byte, error) { - out, err := exec.Command("xauth", "list").Output() - if err != nil { - return nil, err - } + out, err := exec.Command("xauth", "list").Output() + if err != nil { + return nil, err + } - var dispNum string - if strings.HasPrefix(display, ":") { - dispNum = strings.Split(display[1:], ".")[0] - } else { - parts := strings.Split(display, ":") - dispNum = strings.Split(parts[1], ".")[0] - } + var dispNum string + if strings.HasPrefix(display, ":") { + dispNum = strings.Split(display[1:], ".")[0] + } else { + parts := strings.Split(display, ":") + dispNum = strings.Split(parts[1], ".")[0] + } - lines := strings.Split(string(out), "\n") - for _, line := range lines { - fields := strings.Fields(line) - if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" { - entry := fields[0] - if strings.Contains(entry, ":"+dispNum) { - fmt.Printf("Using XAuth cookie from entry: %s\n", entry) - return parseHexCookie(fields[2]), nil - } - } - } + lines := strings.Split(string(out), "\n") + for _, line := range lines { + fields := strings.Fields(line) + if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" { + entry := fields[0] + if strings.Contains(entry, ":"+dispNum) { + fmt.Printf("Using XAuth cookie from entry: %s\n", entry) + return parseHexCookie(fields[2]), nil + } + } + } - return nil, fmt.Errorf("no matching cookie found for display %s", display) + return nil, fmt.Errorf("no matching cookie found for display %s", display) } - func parseHexCookie(hexStr string) []byte { - var cookie []byte - for i := 0; i < len(hexStr); i += 2 { - var val byte - fmt.Sscanf(hexStr[i:i+2], "%02x", &val) - cookie = append(cookie, val) - } - return cookie + var cookie []byte + for i := 0; i < len(hexStr); i += 2 { + var val byte + fmt.Sscanf(hexStr[i:i+2], "%02x", &val) + cookie = append(cookie, val) + } + return cookie } func PatchAuth(data []byte, cookie []byte) []byte { - isLittleEndian := data[0] == 'l' + isLittleEndian := data[0] == 'l' - var authProtoLen, authDataLen int - if isLittleEndian { - authProtoLen = int(data[7])<<8 | int(data[6]) - authDataLen = int(data[9])<<8 | int(data[8]) - } else { - authProtoLen = int(data[6])<<8 | int(data[7]) - authDataLen = int(data[8])<<8 | int(data[9]) - } + var authProtoLen, authDataLen int + if isLittleEndian { + authProtoLen = int(data[7])<<8 | int(data[6]) + authDataLen = int(data[9])<<8 | int(data[8]) + } else { + authProtoLen = int(data[6])<<8 | int(data[7]) + authDataLen = int(data[8])<<8 | int(data[9]) + } - headerLen := 12 - authProtoPad := (authProtoLen + 3) & ^3 - authDataPad := (authDataLen + 3) & ^3 - authDataStart := headerLen + authProtoPad + headerLen := 12 + authProtoPad := (authProtoLen + 3) & ^3 + authDataPad := (authDataLen + 3) & ^3 + authDataStart := headerLen + authProtoPad - // Replace cookie and update length - patched := make([]byte, headerLen+authProtoPad+authDataPad) - copy(patched, data[:headerLen+authProtoPad]) - copy(patched[authDataStart:], cookie) + // Replace cookie and update length + patched := make([]byte, headerLen+authProtoPad+authDataPad) + copy(patched, data[:headerLen+authProtoPad]) + copy(patched[authDataStart:], cookie) - // Update authDataLen to match cookie length - cookieLen := len(cookie) - if isLittleEndian { - patched[8] = byte(cookieLen) - patched[9] = byte(cookieLen >> 8) - } else { - patched[8] = byte(cookieLen >> 8) - patched[9] = byte(cookieLen) - } + // Update authDataLen to match cookie length + cookieLen := len(cookie) + if isLittleEndian { + patched[8] = byte(cookieLen) + patched[9] = byte(cookieLen >> 8) + } else { + patched[8] = byte(cookieLen >> 8) + patched[9] = byte(cookieLen) + } - return patched + return patched } - diff --git a/main.go b/main.go index 72052d6..37d39b0 100644 --- a/main.go +++ b/main.go @@ -1,87 +1,113 @@ package main import ( - "flag" - "fmt" - "log" - "os" - "strings" + "flag" + "fmt" + "log" + "os" + "strings" + "time" + "x11proxy/x11probe" ) const version = "0.1" var ( - overrideDisplay = flag.String("display", "", "Override DISPLAY") - overrideSocket = flag.String("proxy-socket", "/tmp/.X11-unix/X5", "Proxy socket path") - showVersion = flag.Bool("version", false, "Show version and exit") + overrideDisplay = flag.String("display", "", "Override DISPLAY") + overrideSocket = flag.String("proxy-socket", "/tmp/.X11-unix/X5", "Proxy socket path") + showVersion = flag.Bool("version", false, "Show version and exit") ) func init() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage of %s:\n\n", os.Args[0]) - flag.PrintDefaults() - } + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n\n", os.Args[0]) + flag.PrintDefaults() + } } func main() { - flag.Parse() + flag.Parse() - if *showVersion { - fmt.Printf("X11Proxy version %s\n", version) - return - } - fmt.Printf("version %s\n", version) + if *showVersion { + fmt.Printf("X11Proxy version %s\n", version) + return + } + fmt.Printf("version %s\n", version) - display := os.Getenv("DISPLAY") - if *overrideDisplay != "" { - display = *overrideDisplay - } + display := os.Getenv("DISPLAY") + if *overrideDisplay != "" { + display = *overrideDisplay + } - connType, target := resolveDisplay(display) - fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType)) + connType, target := resolveDisplay(display) + fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType)) - err := StartProxy(*overrideSocket, target, connType, display) - if err != nil { - log.Fatalf("Proxy error: %v", err) - } + timeout := 10 * time.Second + status, err := x11probe.ProbeX11Socket(connType, target, timeout) + if err != nil { + log.Fatalf("Connection probe error: %v", err) + } + + switch status { + case x11probe.X11Dead: + log.Fatalf("Target connection is dead or unreachable") + case x11probe.X11AuthRequired: + fmt.Println("Authentication required for the target connection.") + case x11probe.X11Live: + fmt.Println("Connection to target is live and valid.") + } + + switch status { + case x11probe.X11Live: + log.Println("X11 socket is live and accepting connections") + case x11probe.X11AuthRequired: + log.Println("X11 socket is live but requires authentication") + case x11probe.X11Dead: + log.Fatalf("Target connection is dead or unreachable") + } + + err2 := StartProxy(*overrideSocket, target, connType, display) + if err2 != nil { + log.Fatalf("Proxy error: %v", err2) + } } -func connTypeString(t ConnType) string { - if t == Unix { - return "Unix socket" - } - return "TCP" +func connTypeString(t x11probe.ConnType) string { + if t == x11probe.Unix { + return "Unix socket" + } + return "TCP" } func hexDump(buf []byte) string { - var out strings.Builder - for i := 0; i < len(buf); i += 16 { - line := fmt.Sprintf("%08x ", i) + var out strings.Builder + for i := 0; i < len(buf); i += 16 { + line := fmt.Sprintf("%08x ", i) - for j := 0; j < 16; j++ { - if i+j < len(buf) { - line += fmt.Sprintf("%02x ", buf[i+j]) - } else { - line += " " - } - if j == 7 { - line += " " - } - } + for j := 0; j < 16; j++ { + if i+j < len(buf) { + line += fmt.Sprintf("%02x ", buf[i+j]) + } else { + line += " " + } + if j == 7 { + line += " " + } + } - line += " |" + line += " |" - for j := 0; j < 16 && i+j < len(buf); j++ { - b := buf[i+j] - if b >= 32 && b <= 126 { - line += string(b) - } else { - line += "." - } - } + for j := 0; j < 16 && i+j < len(buf); j++ { + b := buf[i+j] + if b >= 32 && b <= 126 { + line += string(b) + } else { + line += "." + } + } - line += "|\n" - out.WriteString(line) - } - return out.String() + line += "|\n" + out.WriteString(line) + } + return out.String() } diff --git a/proxy.go b/proxy.go index cb5f523..6a8ecab 100644 --- a/proxy.go +++ b/proxy.go @@ -1,86 +1,80 @@ package main import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "log" - "net" - "os" + "bytes" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "os" + "x11proxy/x11probe" ) -type ConnType int +func StartProxy(proxyPath, target string, connType x11probe.ConnType, display string) error { + os.Remove(proxyPath) + listener, err := net.Listen("unix", proxyPath) + if err != nil { + return err + } + os.Chmod(proxyPath, 0700) -const ( - Unix ConnType = iota - TCP -) + log.Printf("Proxy listening on %s", proxyPath) -func StartProxy(proxyPath, target string, connType ConnType, display string) error { - os.Remove(proxyPath) - listener, err := net.Listen("unix", proxyPath) - if err != nil { - return err - } - os.Chmod(proxyPath, 0700) - - log.Printf("Proxy listening on %s", proxyPath) - - for { - clientConn, err := listener.Accept() - if err != nil { - log.Printf("Accept error: %v", err) - continue - } - go handleConnection(clientConn, connType, target, display) - } + for { + clientConn, err := listener.Accept() + if err != nil { + log.Printf("Accept error: %v", err) + continue + } + go handleConnection(clientConn, connType, target, display) + } } func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { - // Step 1: Read fixed-length header - header := make([]byte, 12) - if _, err := io.ReadFull(r, header); err != nil { - return nil, fmt.Errorf("failed to read handshake header: %w", err) - } + // Step 1: Read fixed-length header + header := make([]byte, 12) + if _, err := io.ReadFull(r, header); err != nil { + return nil, fmt.Errorf("failed to read handshake header: %w", err) + } - byteOrder := header[0] - isLittle := byteOrder == 'l' + byteOrder := header[0] + isLittle := byteOrder == 'l' - // Step 2: Parse lengths - read16 := func(b []byte) uint16 { - if isLittle { - return binary.LittleEndian.Uint16(b) - } - return binary.BigEndian.Uint16(b) - } + // Step 2: Parse lengths + read16 := func(b []byte) uint16 { + if isLittle { + return binary.LittleEndian.Uint16(b) + } + return binary.BigEndian.Uint16(b) + } - authProtoLen := read16(header[6:8]) - authDataLen := read16(header[8:10]) + authProtoLen := read16(header[6:8]) + authDataLen := read16(header[8:10]) - // Step 3: Read remaining fields - totalLen := int(authProtoLen+authDataLen) - totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen)) - extra := make([]byte, totalLen) - if _, err := io.ReadFull(r, extra); err != nil { - return nil, fmt.Errorf("failed to read handshake auth fields: %w", err) - } + // Step 3: Read remaining fields + totalLen := int(authProtoLen + authDataLen) + totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen)) + extra := make([]byte, totalLen) + if _, err := io.ReadFull(r, extra); err != nil { + return nil, fmt.Errorf("failed to read handshake auth fields: %w", err) + } - // Step 4: Decide whether to patch - patch := authProtoLen == 0 || authDataLen == 0 + // Step 4: Decide whether to patch + patch := authProtoLen == 0 || authDataLen == 0 - var authProto, authData []byte - if patch { - authProto = []byte("MIT-MAGIC-COOKIE-1") - authData = cookie - } else { - authProto = extra[:authProtoLen] - authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)] - } + var authProto, authData []byte + if patch { + authProto = []byte("MIT-MAGIC-COOKIE-1") + authData = cookie + } else { + authProto = extra[:authProtoLen] + authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)] + } - // Step 5: Rebuild handshake - newAuthProtoLen := uint16(len(authProto)) - newAuthDataLen := uint16(len(authData)) + // Step 5: Rebuild handshake + newAuthProtoLen := uint16(len(authProto)) + newAuthDataLen := uint16(len(authData)) // Patch header in-place if isLittle { @@ -91,53 +85,52 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { binary.BigEndian.PutUint16(header[8:10], newAuthDataLen) } - buf := bytes.NewBuffer(header) - buf.Write(authProto) - buf.Write(make([]byte, padLen(len(authProto))-len(authProto))) - buf.Write(authData) - buf.Write(make([]byte, padLen(len(authData))-len(authData))) + buf := bytes.NewBuffer(header) + buf.Write(authProto) + buf.Write(make([]byte, padLen(len(authProto))-len(authProto))) + buf.Write(authData) + buf.Write(make([]byte, padLen(len(authData))-len(authData))) log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s", - newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes())) - return buf.Bytes(), nil + newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes())) + return buf.Bytes(), nil } func padLen(n int) int { - return (n + 3) &^ 3 // round up to next multiple of 4 + return (n + 3) &^ 3 // round up to next multiple of 4 } func binaryOrder(isLittle bool) binary.ByteOrder { - if isLittle { - return binary.LittleEndian - } - return binary.BigEndian + if isLittle { + return binary.LittleEndian + } + return binary.BigEndian } +func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) { + var serverConn net.Conn + var err error -func handleConnection(client net.Conn, connType ConnType, target string, display string) { - var serverConn net.Conn - var err error + if connType == x11probe.Unix { + serverConn, err = net.Dial("unix", target) + } else { + serverConn, err = net.Dial("tcp", target) + } + if err != nil { + log.Printf("Failed to connect to target: %v", err) + client.Close() + return + } - if connType == Unix { - serverConn, err = net.Dial("unix", target) - } else { - serverConn, err = net.Dial("tcp", target) - } - if err != nil { - log.Printf("Failed to connect to target: %v", err) - client.Close() - return - } + log.Printf("New connection from %v", client.RemoteAddr()) - log.Printf("New connection from %v", client.RemoteAddr()) - - cookie, err := getXAuthCookie(display) - if err != nil || cookie == nil { - log.Printf("Failed to get XAuth cookie") - client.Close() - serverConn.Close() - return - } + cookie, err := getXAuthCookie(display) + if err != nil || cookie == nil { + log.Printf("Failed to get XAuth cookie") + client.Close() + serverConn.Close() + return + } log.Printf("About to read the handshake") patched, err := readAndPatchHandshake(client, cookie) @@ -148,44 +141,44 @@ func handleConnection(client net.Conn, connType ConnType, target string, display serverConn.Close() return } - - _, err = serverConn.Write(patched) - if err != nil { - log.Printf("Initial write failed: %v", err) - client.Close() - serverConn.Close() - return - } - done := make(chan struct{}, 2) + _, err = serverConn.Write(patched) + if err != nil { + log.Printf("Initial write failed: %v", err) + client.Close() + serverConn.Close() + return + } - go inspectAndForward(serverConn, client, "client→server", done) - go inspectAndForward(client, serverConn, "server→client", done) + done := make(chan struct{}, 2) - <-done - client.Close() - serverConn.Close() - <-done + go inspectAndForward(serverConn, client, "client→server", done) + go inspectAndForward(client, serverConn, "server→client", done) - log.Printf("Connection closed: %v", client.RemoteAddr()) + <-done + client.Close() + serverConn.Close() + <-done + + log.Printf("Connection closed: %v", client.RemoteAddr()) } func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) { - buf := make([]byte, 4096) - for { - n, err := src.Read(buf) - if err != nil { - log.Printf("[%s] disconnected: %v", label, err) - break - } - if n > 0 { + buf := make([]byte, 4096) + for { + n, err := src.Read(buf) + if err != nil { + log.Printf("[%s] disconnected: %v", label, err) + break + } + if n > 0 { //log.Printf("[%s] forwarded %d bytes", label, n) - _, err := dst.Write(buf[:n]) - if err != nil { - log.Printf("[%s] write error: %v", label, err) - break - } - } - } - done <- struct{}{} + _, err := dst.Write(buf[:n]) + if err != nil { + log.Printf("[%s] write error: %v", label, err) + break + } + } + } + done <- struct{}{} } diff --git a/x11probe/probe.go b/x11probe/probe.go new file mode 100644 index 0000000..4d593bf --- /dev/null +++ b/x11probe/probe.go @@ -0,0 +1,125 @@ +package x11probe + +import ( + "fmt" + "io" + "log" + "net" + "strings" + "time" +) + +type ConnType int + +const ( + Unix ConnType = iota + TCP +) + +// X11Status represents the result of probing an X11 socket. +type X11Status int + +const ( + X11Dead X11Status = iota + X11Live + X11AuthRequired +) + +func ProbeX11Socket(connType ConnType, target string, timeout time.Duration) (X11Status, error) { + var ( + conn net.Conn + err error + ) + + switch connType { + case Unix: + conn, err = net.DialTimeout("unix", target, timeout) + if err != nil { + return X11Dead, fmt.Errorf("unix dial failed: %w", err) + } + case TCP: + addr := parseTCPAddress(target) + conn, err = net.DialTimeout("tcp", addr, timeout) + if err != nil { + return X11Dead, fmt.Errorf("TCP dial failed: %w", err) + } + } + + defer conn.Close() + log.Printf("Successfully connected to %s", target) + + if err := performHandshake(conn); err != nil { + return X11Dead, err + } + + status, err := readResponse(conn, timeout) + if err != nil { + return X11Dead, err + } + + return status, nil +} + +func parseTCPAddress(target string) string { + parts := strings.SplitN(target, ":", 2) + + ipPart := parts[0] + portPart := "0" + + if len(parts) == 2 && portPart != "" { + portPart = parts[1] + } + + return fmt.Sprintf("%s:%s", ipPart, portPart) +} + +func performHandshake(conn net.Conn) error { + // Minimal valid handshake with correct padding + handshake := []byte{ + 'l', 0, // Byte order + 0, 11, 0, 0, // Protocol major/minor + 0, 0, // Auth name length + 0, 0, // Auth data length + 0, 0, 0, 0, // Padding + } + + if _, err := conn.Write(handshake); err != nil { + return fmt.Errorf("write failed: %w", err) + } + return nil +} + +func readResponse(conn net.Conn, timeout time.Duration) (X11Status, error) { + buf := make([]byte, 1) + + deadline := time.Now().Add(timeout) + if err := conn.SetReadDeadline(deadline); err != nil { + return X11Dead, fmt.Errorf("SetReadDeadline failed: %w", err) + } + + n, err := conn.Read(buf) + if err != nil { + if err == io.EOF { + return X11AuthRequired, nil + } + return X11Dead, fmt.Errorf("read failed: %w", err) + } + + if n == 0 { + return X11Dead, fmt.Errorf("read returned 0 bytes") + } + + switch buf[0] { + case 0x01: + return X11Live, nil + case 0x02: + return X11AuthRequired, nil + case 0x00: + fmt.Printf("X11 handshake rejected (value %x)\n", buf[0]) + return X11AuthRequired, nil + default: + fmt.Printf("unexpected value %x\n", buf[0]) + return X11Dead, nil + } + +}