initial version of the basic x11 proxy

This commit is contained in:
2025-08-02 13:07:49 -04:00
commit 02f07892ce
4 changed files with 382 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
.vscode
x11proxy
build.sh

116
auth.go Normal file
View File

@@ -0,0 +1,116 @@
package main
import (
"fmt"
"os"
"os/exec"
"path"
"strconv"
"strings"
"net"
)
func resolveDisplay(display string) (ConnType, string) {
var dispNum int
var err error
if strings.HasPrefix(display, ":") {
dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0])
} else {
parts := strings.Split(display, ":")
dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0])
}
if err != nil {
dispNum = 0
}
unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum))
if _, err := os.Stat(unixPath); err == nil {
return Unix, unixPath
}
tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum)
conn, err := net.Dial("tcp", tcpAddr)
if err == nil {
conn.Close()
return TCP, tcpAddr
}
return Unix, unixPath
}
func getXAuthCookie(display string) ([]byte, error) {
out, err := exec.Command("xauth", "list").Output()
if err != nil {
return nil, err
}
var dispNum string
if strings.HasPrefix(display, ":") {
dispNum = strings.Split(display[1:], ".")[0]
} else {
parts := strings.Split(display, ":")
dispNum = strings.Split(parts[1], ".")[0]
}
lines := strings.Split(string(out), "\n")
for _, line := range lines {
fields := strings.Fields(line)
if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" {
entry := fields[0]
if strings.Contains(entry, ":"+dispNum) {
fmt.Printf("Using XAuth cookie from entry: %s\n", entry)
return parseHexCookie(fields[2]), nil
}
}
}
return nil, fmt.Errorf("no matching cookie found for display %s", display)
}
func parseHexCookie(hexStr string) []byte {
var cookie []byte
for i := 0; i < len(hexStr); i += 2 {
var val byte
fmt.Sscanf(hexStr[i:i+2], "%02x", &val)
cookie = append(cookie, val)
}
return cookie
}
func PatchAuth(data []byte, cookie []byte) []byte {
isLittleEndian := data[0] == 'l'
var authProtoLen, authDataLen int
if isLittleEndian {
authProtoLen = int(data[7])<<8 | int(data[6])
authDataLen = int(data[9])<<8 | int(data[8])
} else {
authProtoLen = int(data[6])<<8 | int(data[7])
authDataLen = int(data[8])<<8 | int(data[9])
}
headerLen := 12
authProtoPad := (authProtoLen + 3) & ^3
authDataPad := (authDataLen + 3) & ^3
authDataStart := headerLen + authProtoPad
// Replace cookie and update length
patched := make([]byte, headerLen+authProtoPad+authDataPad)
copy(patched, data[:headerLen+authProtoPad])
copy(patched[authDataStart:], cookie)
// Update authDataLen to match cookie length
cookieLen := len(cookie)
if isLittleEndian {
patched[8] = byte(cookieLen)
patched[9] = byte(cookieLen >> 8)
} else {
patched[8] = byte(cookieLen >> 8)
patched[9] = byte(cookieLen)
}
return patched
}

72
main.go Normal file
View File

@@ -0,0 +1,72 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"strings"
)
func main() {
display := os.Getenv("DISPLAY")
proxySocket := "/tmp/.X11-unix/X5"
overrideDisplay := flag.String("display", "", "Override DISPLAY")
overrideSocket := flag.String("proxy-socket", proxySocket, "Proxy socket path")
flag.Parse()
if *overrideDisplay != "" {
display = *overrideDisplay
}
connType, target := resolveDisplay(display)
fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType))
err := StartProxy(*overrideSocket, target, connType, display)
if err != nil {
log.Fatalf("Proxy error: %v", err)
}
}
func connTypeString(t ConnType) string {
if t == Unix {
return "Unix socket"
}
return "TCP"
}
func hexDump(buf []byte) string {
var out strings.Builder
for i := 0; i < len(buf); i += 16 {
line := fmt.Sprintf("%08x ", i)
// Hex section
for j := 0; j < 16; j++ {
if i+j < len(buf) {
line += fmt.Sprintf("%02x ", buf[i+j])
} else {
line += " "
}
if j == 7 {
line += " " // extra space in middle
}
}
line += " |"
// ASCII section
for j := 0; j < 16 && i+j < len(buf); j++ {
b := buf[i+j]
if b >= 32 && b <= 126 {
line += string(b)
} else {
line += "."
}
}
line += "|\n"
out.WriteString(line)
}
return out.String()
}

191
proxy.go Normal file
View File

@@ -0,0 +1,191 @@
package main
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"os"
)
type ConnType int
const (
Unix ConnType = iota
TCP
)
func StartProxy(proxyPath, target string, connType ConnType, display string) error {
os.Remove(proxyPath)
listener, err := net.Listen("unix", proxyPath)
if err != nil {
return err
}
os.Chmod(proxyPath, 0700)
log.Printf("Proxy listening on %s", proxyPath)
for {
clientConn, err := listener.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go handleConnection(clientConn, connType, target, display)
}
}
func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
// Step 1: Read fixed-length header
header := make([]byte, 12)
if _, err := io.ReadFull(r, header); err != nil {
return nil, fmt.Errorf("failed to read handshake header: %w", err)
}
byteOrder := header[0]
isLittle := byteOrder == 'l'
// Step 2: Parse lengths
read16 := func(b []byte) uint16 {
if isLittle {
return binary.LittleEndian.Uint16(b)
}
return binary.BigEndian.Uint16(b)
}
authProtoLen := read16(header[6:8])
authDataLen := read16(header[8:10])
// Step 3: Read remaining fields
totalLen := int(authProtoLen+authDataLen)
totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen))
extra := make([]byte, totalLen)
if _, err := io.ReadFull(r, extra); err != nil {
return nil, fmt.Errorf("failed to read handshake auth fields: %w", err)
}
// Step 4: Decide whether to patch
patch := authProtoLen == 0 || authDataLen == 0
var authProto, authData []byte
if patch {
authProto = []byte("MIT-MAGIC-COOKIE-1")
authData = cookie
} else {
authProto = extra[:authProtoLen]
authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)]
}
// Step 5: Rebuild handshake
newAuthProtoLen := uint16(len(authProto))
newAuthDataLen := uint16(len(authData))
// Patch header in-place
if isLittle {
binary.LittleEndian.PutUint16(header[6:8], newAuthProtoLen)
binary.LittleEndian.PutUint16(header[8:10], newAuthDataLen)
} else {
binary.BigEndian.PutUint16(header[6:8], newAuthProtoLen)
binary.BigEndian.PutUint16(header[8:10], newAuthDataLen)
}
buf := bytes.NewBuffer(header)
buf.Write(authProto)
buf.Write(make([]byte, padLen(len(authProto))-len(authProto)))
buf.Write(authData)
buf.Write(make([]byte, padLen(len(authData))-len(authData)))
log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s",
newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes()))
return buf.Bytes(), nil
}
func padLen(n int) int {
return (n + 3) &^ 3 // round up to next multiple of 4
}
func binaryOrder(isLittle bool) binary.ByteOrder {
if isLittle {
return binary.LittleEndian
}
return binary.BigEndian
}
func handleConnection(client net.Conn, connType ConnType, target string, display string) {
var serverConn net.Conn
var err error
if connType == Unix {
serverConn, err = net.Dial("unix", target)
} else {
serverConn, err = net.Dial("tcp", target)
}
if err != nil {
log.Printf("Failed to connect to target: %v", err)
client.Close()
return
}
log.Printf("New connection from %v", client.RemoteAddr())
cookie, err := getXAuthCookie(display)
if err != nil || cookie == nil {
log.Printf("Failed to get XAuth cookie")
client.Close()
serverConn.Close()
return
}
log.Printf("About to read the handshake")
patched, err := readAndPatchHandshake(client, cookie)
if err != nil {
log.Printf("Initial read failed: %v", err)
log.Fatal(err)
client.Close()
serverConn.Close()
return
}
_, err = serverConn.Write(patched)
if err != nil {
log.Printf("Initial write failed: %v", err)
client.Close()
serverConn.Close()
return
}
done := make(chan struct{}, 2)
go inspectAndForward(serverConn, client, "client→server", done)
go inspectAndForward(client, serverConn, "server→client", done)
<-done
client.Close()
serverConn.Close()
<-done
log.Printf("Connection closed: %v", client.RemoteAddr())
}
func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) {
buf := make([]byte, 4096)
for {
n, err := src.Read(buf)
if err != nil {
log.Printf("[%s] disconnected: %v", label, err)
break
}
if n > 0 {
//log.Printf("[%s] forwarded %d bytes", label, n)
_, err := dst.Write(buf[:n])
if err != nil {
log.Printf("[%s] write error: %v", label, err)
break
}
}
}
done <- struct{}{}
}