add X11 socket/port probe

This commit is contained in:
2025-08-02 22:23:57 -04:00
parent 6b87c9e688
commit 1795e91387
4 changed files with 422 additions and 279 deletions

175
auth.go
View File

@@ -1,116 +1,115 @@
package main package main
import ( import (
"fmt" "fmt"
"os"
"os/exec"
"path"
"strconv"
"strings"
"net" "net"
"os"
"os/exec"
"path"
"strconv"
"strings"
"x11proxy/x11probe"
) )
func resolveDisplay(display string) (ConnType, string) { func resolveDisplay(display string) (x11probe.ConnType, string) {
var dispNum int var dispNum int
var err error var err error
if strings.HasPrefix(display, ":") { if strings.HasPrefix(display, ":") {
dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0]) dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0])
} else { } else {
parts := strings.Split(display, ":") parts := strings.Split(display, ":")
dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0]) dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0])
} }
if err != nil { if err != nil {
dispNum = 0 dispNum = 0
} }
unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum)) unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum))
if _, err := os.Stat(unixPath); err == nil { if _, err := os.Stat(unixPath); err == nil {
return Unix, unixPath return x11probe.Unix, unixPath
} }
tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum) tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum)
conn, err := net.Dial("tcp", tcpAddr) conn, err := net.Dial("tcp", tcpAddr)
if err == nil { if err == nil {
conn.Close() conn.Close()
return TCP, tcpAddr return x11probe.TCP, tcpAddr
} }
return Unix, unixPath return x11probe.Unix, unixPath
} }
func getXAuthCookie(display string) ([]byte, error) { func getXAuthCookie(display string) ([]byte, error) {
out, err := exec.Command("xauth", "list").Output() out, err := exec.Command("xauth", "list").Output()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var dispNum string var dispNum string
if strings.HasPrefix(display, ":") { if strings.HasPrefix(display, ":") {
dispNum = strings.Split(display[1:], ".")[0] dispNum = strings.Split(display[1:], ".")[0]
} else { } else {
parts := strings.Split(display, ":") parts := strings.Split(display, ":")
dispNum = strings.Split(parts[1], ".")[0] dispNum = strings.Split(parts[1], ".")[0]
} }
lines := strings.Split(string(out), "\n") lines := strings.Split(string(out), "\n")
for _, line := range lines { for _, line := range lines {
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" { if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" {
entry := fields[0] entry := fields[0]
if strings.Contains(entry, ":"+dispNum) { if strings.Contains(entry, ":"+dispNum) {
fmt.Printf("Using XAuth cookie from entry: %s\n", entry) fmt.Printf("Using XAuth cookie from entry: %s\n", entry)
return parseHexCookie(fields[2]), nil return parseHexCookie(fields[2]), nil
} }
} }
} }
return nil, fmt.Errorf("no matching cookie found for display %s", display) return nil, fmt.Errorf("no matching cookie found for display %s", display)
} }
func parseHexCookie(hexStr string) []byte { func parseHexCookie(hexStr string) []byte {
var cookie []byte var cookie []byte
for i := 0; i < len(hexStr); i += 2 { for i := 0; i < len(hexStr); i += 2 {
var val byte var val byte
fmt.Sscanf(hexStr[i:i+2], "%02x", &val) fmt.Sscanf(hexStr[i:i+2], "%02x", &val)
cookie = append(cookie, val) cookie = append(cookie, val)
} }
return cookie return cookie
} }
func PatchAuth(data []byte, cookie []byte) []byte { func PatchAuth(data []byte, cookie []byte) []byte {
isLittleEndian := data[0] == 'l' isLittleEndian := data[0] == 'l'
var authProtoLen, authDataLen int var authProtoLen, authDataLen int
if isLittleEndian { if isLittleEndian {
authProtoLen = int(data[7])<<8 | int(data[6]) authProtoLen = int(data[7])<<8 | int(data[6])
authDataLen = int(data[9])<<8 | int(data[8]) authDataLen = int(data[9])<<8 | int(data[8])
} else { } else {
authProtoLen = int(data[6])<<8 | int(data[7]) authProtoLen = int(data[6])<<8 | int(data[7])
authDataLen = int(data[8])<<8 | int(data[9]) authDataLen = int(data[8])<<8 | int(data[9])
} }
headerLen := 12 headerLen := 12
authProtoPad := (authProtoLen + 3) & ^3 authProtoPad := (authProtoLen + 3) & ^3
authDataPad := (authDataLen + 3) & ^3 authDataPad := (authDataLen + 3) & ^3
authDataStart := headerLen + authProtoPad authDataStart := headerLen + authProtoPad
// Replace cookie and update length // Replace cookie and update length
patched := make([]byte, headerLen+authProtoPad+authDataPad) patched := make([]byte, headerLen+authProtoPad+authDataPad)
copy(patched, data[:headerLen+authProtoPad]) copy(patched, data[:headerLen+authProtoPad])
copy(patched[authDataStart:], cookie) copy(patched[authDataStart:], cookie)
// Update authDataLen to match cookie length // Update authDataLen to match cookie length
cookieLen := len(cookie) cookieLen := len(cookie)
if isLittleEndian { if isLittleEndian {
patched[8] = byte(cookieLen) patched[8] = byte(cookieLen)
patched[9] = byte(cookieLen >> 8) patched[9] = byte(cookieLen >> 8)
} else { } else {
patched[8] = byte(cookieLen >> 8) patched[8] = byte(cookieLen >> 8)
patched[9] = byte(cookieLen) patched[9] = byte(cookieLen)
} }
return patched return patched
} }

144
main.go
View File

@@ -1,87 +1,113 @@
package main package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"os" "os"
"strings" "strings"
"time"
"x11proxy/x11probe"
) )
const version = "0.1" const version = "0.1"
var ( var (
overrideDisplay = flag.String("display", "", "Override DISPLAY") overrideDisplay = flag.String("display", "", "Override DISPLAY")
overrideSocket = flag.String("proxy-socket", "/tmp/.X11-unix/X5", "Proxy socket path") overrideSocket = flag.String("proxy-socket", "/tmp/.X11-unix/X5", "Proxy socket path")
showVersion = flag.Bool("version", false, "Show version and exit") showVersion = flag.Bool("version", false, "Show version and exit")
) )
func init() { func init() {
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage of %s:\n\n", os.Args[0])
flag.PrintDefaults() flag.PrintDefaults()
} }
} }
func main() { func main() {
flag.Parse() flag.Parse()
if *showVersion { if *showVersion {
fmt.Printf("X11Proxy version %s\n", version) fmt.Printf("X11Proxy version %s\n", version)
return return
} }
fmt.Printf("version %s\n", version) fmt.Printf("version %s\n", version)
display := os.Getenv("DISPLAY") display := os.Getenv("DISPLAY")
if *overrideDisplay != "" { if *overrideDisplay != "" {
display = *overrideDisplay display = *overrideDisplay
} }
connType, target := resolveDisplay(display) connType, target := resolveDisplay(display)
fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType)) fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType))
err := StartProxy(*overrideSocket, target, connType, display) timeout := 10 * time.Second
if err != nil { status, err := x11probe.ProbeX11Socket(connType, target, timeout)
log.Fatalf("Proxy error: %v", err) if err != nil {
} log.Fatalf("Connection probe error: %v", err)
}
switch status {
case x11probe.X11Dead:
log.Fatalf("Target connection is dead or unreachable")
case x11probe.X11AuthRequired:
fmt.Println("Authentication required for the target connection.")
case x11probe.X11Live:
fmt.Println("Connection to target is live and valid.")
}
switch status {
case x11probe.X11Live:
log.Println("X11 socket is live and accepting connections")
case x11probe.X11AuthRequired:
log.Println("X11 socket is live but requires authentication")
case x11probe.X11Dead:
log.Fatalf("Target connection is dead or unreachable")
}
err2 := StartProxy(*overrideSocket, target, connType, display)
if err2 != nil {
log.Fatalf("Proxy error: %v", err2)
}
} }
func connTypeString(t ConnType) string { func connTypeString(t x11probe.ConnType) string {
if t == Unix { if t == x11probe.Unix {
return "Unix socket" return "Unix socket"
} }
return "TCP" return "TCP"
} }
func hexDump(buf []byte) string { func hexDump(buf []byte) string {
var out strings.Builder var out strings.Builder
for i := 0; i < len(buf); i += 16 { for i := 0; i < len(buf); i += 16 {
line := fmt.Sprintf("%08x ", i) line := fmt.Sprintf("%08x ", i)
for j := 0; j < 16; j++ { for j := 0; j < 16; j++ {
if i+j < len(buf) { if i+j < len(buf) {
line += fmt.Sprintf("%02x ", buf[i+j]) line += fmt.Sprintf("%02x ", buf[i+j])
} else { } else {
line += " " line += " "
} }
if j == 7 { if j == 7 {
line += " " line += " "
} }
} }
line += " |" line += " |"
for j := 0; j < 16 && i+j < len(buf); j++ { for j := 0; j < 16 && i+j < len(buf); j++ {
b := buf[i+j] b := buf[i+j]
if b >= 32 && b <= 126 { if b >= 32 && b <= 126 {
line += string(b) line += string(b)
} else { } else {
line += "." line += "."
} }
} }
line += "|\n" line += "|\n"
out.WriteString(line) out.WriteString(line)
} }
return out.String() return out.String()
} }

255
proxy.go
View File

@@ -1,86 +1,80 @@
package main package main
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"os" "os"
"x11proxy/x11probe"
) )
type ConnType int func StartProxy(proxyPath, target string, connType x11probe.ConnType, display string) error {
os.Remove(proxyPath)
listener, err := net.Listen("unix", proxyPath)
if err != nil {
return err
}
os.Chmod(proxyPath, 0700)
const ( log.Printf("Proxy listening on %s", proxyPath)
Unix ConnType = iota
TCP
)
func StartProxy(proxyPath, target string, connType ConnType, display string) error { for {
os.Remove(proxyPath) clientConn, err := listener.Accept()
listener, err := net.Listen("unix", proxyPath) if err != nil {
if err != nil { log.Printf("Accept error: %v", err)
return err continue
} }
os.Chmod(proxyPath, 0700) go handleConnection(clientConn, connType, target, display)
}
log.Printf("Proxy listening on %s", proxyPath)
for {
clientConn, err := listener.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go handleConnection(clientConn, connType, target, display)
}
} }
func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
// Step 1: Read fixed-length header // Step 1: Read fixed-length header
header := make([]byte, 12) header := make([]byte, 12)
if _, err := io.ReadFull(r, header); err != nil { if _, err := io.ReadFull(r, header); err != nil {
return nil, fmt.Errorf("failed to read handshake header: %w", err) return nil, fmt.Errorf("failed to read handshake header: %w", err)
} }
byteOrder := header[0] byteOrder := header[0]
isLittle := byteOrder == 'l' isLittle := byteOrder == 'l'
// Step 2: Parse lengths // Step 2: Parse lengths
read16 := func(b []byte) uint16 { read16 := func(b []byte) uint16 {
if isLittle { if isLittle {
return binary.LittleEndian.Uint16(b) return binary.LittleEndian.Uint16(b)
} }
return binary.BigEndian.Uint16(b) return binary.BigEndian.Uint16(b)
} }
authProtoLen := read16(header[6:8]) authProtoLen := read16(header[6:8])
authDataLen := read16(header[8:10]) authDataLen := read16(header[8:10])
// Step 3: Read remaining fields // Step 3: Read remaining fields
totalLen := int(authProtoLen+authDataLen) totalLen := int(authProtoLen + authDataLen)
totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen)) totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen))
extra := make([]byte, totalLen) extra := make([]byte, totalLen)
if _, err := io.ReadFull(r, extra); err != nil { if _, err := io.ReadFull(r, extra); err != nil {
return nil, fmt.Errorf("failed to read handshake auth fields: %w", err) return nil, fmt.Errorf("failed to read handshake auth fields: %w", err)
} }
// Step 4: Decide whether to patch // Step 4: Decide whether to patch
patch := authProtoLen == 0 || authDataLen == 0 patch := authProtoLen == 0 || authDataLen == 0
var authProto, authData []byte var authProto, authData []byte
if patch { if patch {
authProto = []byte("MIT-MAGIC-COOKIE-1") authProto = []byte("MIT-MAGIC-COOKIE-1")
authData = cookie authData = cookie
} else { } else {
authProto = extra[:authProtoLen] authProto = extra[:authProtoLen]
authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)] authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)]
} }
// Step 5: Rebuild handshake // Step 5: Rebuild handshake
newAuthProtoLen := uint16(len(authProto)) newAuthProtoLen := uint16(len(authProto))
newAuthDataLen := uint16(len(authData)) newAuthDataLen := uint16(len(authData))
// Patch header in-place // Patch header in-place
if isLittle { if isLittle {
@@ -91,53 +85,52 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
binary.BigEndian.PutUint16(header[8:10], newAuthDataLen) binary.BigEndian.PutUint16(header[8:10], newAuthDataLen)
} }
buf := bytes.NewBuffer(header) buf := bytes.NewBuffer(header)
buf.Write(authProto) buf.Write(authProto)
buf.Write(make([]byte, padLen(len(authProto))-len(authProto))) buf.Write(make([]byte, padLen(len(authProto))-len(authProto)))
buf.Write(authData) buf.Write(authData)
buf.Write(make([]byte, padLen(len(authData))-len(authData))) buf.Write(make([]byte, padLen(len(authData))-len(authData)))
log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s", log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s",
newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes())) newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes()))
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func padLen(n int) int { func padLen(n int) int {
return (n + 3) &^ 3 // round up to next multiple of 4 return (n + 3) &^ 3 // round up to next multiple of 4
} }
func binaryOrder(isLittle bool) binary.ByteOrder { func binaryOrder(isLittle bool) binary.ByteOrder {
if isLittle { if isLittle {
return binary.LittleEndian return binary.LittleEndian
} }
return binary.BigEndian return binary.BigEndian
} }
func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) {
var serverConn net.Conn
var err error
func handleConnection(client net.Conn, connType ConnType, target string, display string) { if connType == x11probe.Unix {
var serverConn net.Conn serverConn, err = net.Dial("unix", target)
var err error } else {
serverConn, err = net.Dial("tcp", target)
}
if err != nil {
log.Printf("Failed to connect to target: %v", err)
client.Close()
return
}
if connType == Unix { log.Printf("New connection from %v", client.RemoteAddr())
serverConn, err = net.Dial("unix", target)
} else {
serverConn, err = net.Dial("tcp", target)
}
if err != nil {
log.Printf("Failed to connect to target: %v", err)
client.Close()
return
}
log.Printf("New connection from %v", client.RemoteAddr()) cookie, err := getXAuthCookie(display)
if err != nil || cookie == nil {
cookie, err := getXAuthCookie(display) log.Printf("Failed to get XAuth cookie")
if err != nil || cookie == nil { client.Close()
log.Printf("Failed to get XAuth cookie") serverConn.Close()
client.Close() return
serverConn.Close() }
return
}
log.Printf("About to read the handshake") log.Printf("About to read the handshake")
patched, err := readAndPatchHandshake(client, cookie) patched, err := readAndPatchHandshake(client, cookie)
@@ -149,43 +142,43 @@ func handleConnection(client net.Conn, connType ConnType, target string, display
return return
} }
_, err = serverConn.Write(patched) _, err = serverConn.Write(patched)
if err != nil { if err != nil {
log.Printf("Initial write failed: %v", err) log.Printf("Initial write failed: %v", err)
client.Close() client.Close()
serverConn.Close() serverConn.Close()
return return
} }
done := make(chan struct{}, 2) done := make(chan struct{}, 2)
go inspectAndForward(serverConn, client, "client→server", done) go inspectAndForward(serverConn, client, "client→server", done)
go inspectAndForward(client, serverConn, "server→client", done) go inspectAndForward(client, serverConn, "server→client", done)
<-done <-done
client.Close() client.Close()
serverConn.Close() serverConn.Close()
<-done <-done
log.Printf("Connection closed: %v", client.RemoteAddr()) log.Printf("Connection closed: %v", client.RemoteAddr())
} }
func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) { func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) {
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
n, err := src.Read(buf) n, err := src.Read(buf)
if err != nil { if err != nil {
log.Printf("[%s] disconnected: %v", label, err) log.Printf("[%s] disconnected: %v", label, err)
break break
} }
if n > 0 { if n > 0 {
//log.Printf("[%s] forwarded %d bytes", label, n) //log.Printf("[%s] forwarded %d bytes", label, n)
_, err := dst.Write(buf[:n]) _, err := dst.Write(buf[:n])
if err != nil { if err != nil {
log.Printf("[%s] write error: %v", label, err) log.Printf("[%s] write error: %v", label, err)
break break
} }
} }
} }
done <- struct{}{} done <- struct{}{}
} }

125
x11probe/probe.go Normal file
View File

@@ -0,0 +1,125 @@
package x11probe
import (
"fmt"
"io"
"log"
"net"
"strings"
"time"
)
type ConnType int
const (
Unix ConnType = iota
TCP
)
// X11Status represents the result of probing an X11 socket.
type X11Status int
const (
X11Dead X11Status = iota
X11Live
X11AuthRequired
)
func ProbeX11Socket(connType ConnType, target string, timeout time.Duration) (X11Status, error) {
var (
conn net.Conn
err error
)
switch connType {
case Unix:
conn, err = net.DialTimeout("unix", target, timeout)
if err != nil {
return X11Dead, fmt.Errorf("unix dial failed: %w", err)
}
case TCP:
addr := parseTCPAddress(target)
conn, err = net.DialTimeout("tcp", addr, timeout)
if err != nil {
return X11Dead, fmt.Errorf("TCP dial failed: %w", err)
}
}
defer conn.Close()
log.Printf("Successfully connected to %s", target)
if err := performHandshake(conn); err != nil {
return X11Dead, err
}
status, err := readResponse(conn, timeout)
if err != nil {
return X11Dead, err
}
return status, nil
}
func parseTCPAddress(target string) string {
parts := strings.SplitN(target, ":", 2)
ipPart := parts[0]
portPart := "0"
if len(parts) == 2 && portPart != "" {
portPart = parts[1]
}
return fmt.Sprintf("%s:%s", ipPart, portPart)
}
func performHandshake(conn net.Conn) error {
// Minimal valid handshake with correct padding
handshake := []byte{
'l', 0, // Byte order
0, 11, 0, 0, // Protocol major/minor
0, 0, // Auth name length
0, 0, // Auth data length
0, 0, 0, 0, // Padding
}
if _, err := conn.Write(handshake); err != nil {
return fmt.Errorf("write failed: %w", err)
}
return nil
}
func readResponse(conn net.Conn, timeout time.Duration) (X11Status, error) {
buf := make([]byte, 1)
deadline := time.Now().Add(timeout)
if err := conn.SetReadDeadline(deadline); err != nil {
return X11Dead, fmt.Errorf("SetReadDeadline failed: %w", err)
}
n, err := conn.Read(buf)
if err != nil {
if err == io.EOF {
return X11AuthRequired, nil
}
return X11Dead, fmt.Errorf("read failed: %w", err)
}
if n == 0 {
return X11Dead, fmt.Errorf("read returned 0 bytes")
}
switch buf[0] {
case 0x01:
return X11Live, nil
case 0x02:
return X11AuthRequired, nil
case 0x00:
fmt.Printf("X11 handshake rejected (value %x)\n", buf[0])
return X11AuthRequired, nil
default:
fmt.Printf("unexpected value %x\n", buf[0])
return X11Dead, nil
}
}