Files
X11Proxy/proxy.go
2025-08-03 01:24:21 -04:00

198 lines
4.8 KiB
Go

package main
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"os"
"time"
"x11proxy/x11probe"
)
// StartProxy listens on the Unix socket at proxyPath and, for every accepted
// client, starts handleConnection in its own goroutine to relay traffic to
// target (dialled as a Unix or TCP address depending on connType). display is
// passed through to the handler, which uses it to look up the X authorization
// cookie.
//
// Any stale file at proxyPath is removed first. StartProxy returns only if the
// listener cannot be created or its permissions cannot be restricted to the
// owner; otherwise it serves forever.
func StartProxy(proxyPath, target string, connType x11probe.ConnType, display string) error {
	// Best-effort cleanup of a socket left behind by a previous run; a missing
	// file is the normal case, so the error is intentionally ignored.
	_ = os.Remove(proxyPath)

	listener, err := net.Listen("unix", proxyPath)
	if err != nil {
		return err
	}

	// The handler injects the user's X cookie into handshakes that carry no
	// credentials, so anyone able to connect here effectively gets X access.
	// Refuse to serve if the socket cannot be restricted to its owner.
	// NOTE(review): between Listen and Chmod the socket exists with
	// umask-derived permissions, and connections queued in that window are
	// still accepted below. Creating the socket inside a private (0700)
	// directory would close that gap.
	if err := os.Chmod(proxyPath, 0700); err != nil {
		listener.Close()
		return fmt.Errorf("restricting permissions on %s: %w", proxyPath, err)
	}

	log.Printf("Proxy listening on %s", proxyPath)
	for {
		clientConn, err := listener.Accept()
		if err != nil {
			log.Printf("Accept error: %v", err)
			// Back off briefly so a persistent failure (for example running out
			// of file descriptors) does not become a CPU-burning busy loop.
			time.Sleep(100 * time.Millisecond)
			continue
		}
		go handleConnection(clientConn, connType, target, display)
	}
}
// readAndPatchHandshake reads an X11 connection-setup request from r and
// returns the bytes that should be forwarded to the real server.
//
// The request starts with a 12-byte fixed header. Byte 0 is the byte-order
// flag ('l' for little-endian, 'B' for big-endian). Bytes 6-7 and 8-9 hold the
// lengths of the authorization protocol name and the authorization data. The
// name and data follow the header, each zero-padded to a multiple of four
// bytes (see padLen).
//
// If the client sent an empty protocol name or empty data, the request is
// rewritten to carry "MIT-MAGIC-COOKIE-1" with cookie as the data. Otherwise
// the client's own credentials are forwarded unchanged.
//
// When r is a net.Conn, a 3-second read deadline bounds the handshake read.
// The deadline is cleared again on every return path.
//
// An error is returned if the request is truncated, the byte-order flag is
// invalid, or the authorization data would not fit in a 16-bit length field.
func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
	if conn, ok := r.(net.Conn); ok {
		conn.SetReadDeadline(time.Now().Add(3 * time.Second))
		// Clear on every exit path so later forwarding reads on this
		// connection are not cut off by the handshake timeout.
		defer conn.SetReadDeadline(time.Time{})
	}

	header := make([]byte, 12)
	if n, err := io.ReadFull(r, header); err != nil {
		log.Printf("Handshake header read failed after %d bytes: %v", n, err)
		return nil, fmt.Errorf("failed to read handshake header: %w", err)
	}

	var order binary.ByteOrder
	switch header[0] {
	case 'l':
		order = binary.LittleEndian
	case 'B':
		order = binary.BigEndian
	default:
		return nil, fmt.Errorf("invalid handshake byte-order byte 0x%02x", header[0])
	}

	authProtoLen := int(order.Uint16(header[6:8]))
	authDataLen := int(order.Uint16(header[8:10]))

	extra := make([]byte, padLen(authProtoLen)+padLen(authDataLen))
	if n, err := io.ReadFull(r, extra); err != nil {
		log.Printf("Handshake auth read failed after %d bytes: %v", n, err)
		return nil, fmt.Errorf("failed to read handshake auth fields: %w", err)
	}

	// Patch only when the client supplied no usable credentials of its own.
	patch := authProtoLen == 0 || authDataLen == 0
	var authProto, authData []byte
	if patch {
		authProto = []byte("MIT-MAGIC-COOKIE-1")
		authData = cookie
	} else {
		dataStart := padLen(authProtoLen)
		authProto = extra[:authProtoLen]
		authData = extra[dataStart : dataStart+authDataLen]
	}
	// The length fields are 16 bits wide; a longer cookie would be truncated.
	if len(authData) > 0xFFFF {
		return nil, fmt.Errorf("authorization data too long: %d bytes", len(authData))
	}

	order.PutUint16(header[6:8], uint16(len(authProto)))
	order.PutUint16(header[8:10], uint16(len(authData)))

	var buf bytes.Buffer
	buf.Grow(len(header) + padLen(len(authProto)) + padLen(len(authData)))
	buf.Write(header)
	for _, field := range [][]byte{authProto, authData} {
		buf.Write(field)
		buf.Write(make([]byte, padLen(len(field))-len(field)))
	}

	// Log lengths only: the rebuilt request contains the authorization cookie,
	// which is a secret and must not be written to logs.
	log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d, patched=%t",
		len(authProto), len(authData), patch)
	return buf.Bytes(), nil
}
// padLen returns n rounded up to the nearest multiple of four. It is used to
// size the zero-padded variable-length fields of the X11 connection setup.
func padLen(n int) int {
	const alignMask = 4 - 1
	// (-n) & 3 is the distance from n up to the next multiple of four.
	return n + (-n & alignMask)
}
// binaryOrder returns binary.LittleEndian when isLittle is true and
// binary.BigEndian otherwise.
func binaryOrder(isLittle bool) binary.ByteOrder {
	var order binary.ByteOrder = binary.BigEndian
	if isLittle {
		order = binary.LittleEndian
	}
	return order
}
// handleConnection relays one client connection to the X server at target.
// The target is dialled over Unix or TCP according to connType.
//
// If an X authorization cookie can be obtained for display, the client's
// connection-setup request is read first and rewritten by
// readAndPatchHandshake before being sent to the server. Otherwise the
// connection is forwarded untouched.
//
// Bytes are then copied in both directions until either direction stops.
// Both connections are closed before returning. Failures only affect this
// connection; they are logged and never terminate the process.
func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) {
	network := "tcp"
	if connType == x11probe.Unix {
		network = "unix"
	}
	serverConn, err := net.Dial(network, target)
	if err != nil {
		log.Printf("Failed to connect to target: %v", err)
		client.Close()
		return
	}
	log.Printf("New connection from %v", client.RemoteAddr())

	closeBoth := func() {
		client.Close()
		serverConn.Close()
	}

	cookie, err := getXAuthCookie(display)
	if err != nil || cookie == nil {
		log.Printf("Failed to get XAuth cookie, try forwarding without patching")
	} else {
		log.Printf("About to read the handshake")
		patched, readErr := readAndPatchHandshake(client, cookie)
		if readErr != nil {
			// Drop just this client: a slow or malformed client must not be
			// able to take down the whole proxy process.
			log.Printf("Initial read failed: %v", readErr)
			closeBoth()
			return
		}
		if _, writeErr := serverConn.Write(patched); writeErr != nil {
			log.Printf("Initial write failed: %v", writeErr)
			closeBoth()
			return
		}
	}

	// Buffered so neither forwarder blocks on its completion signal.
	done := make(chan struct{}, 2)
	go inspectAndForward(serverConn, client, "client→server", done)
	go inspectAndForward(client, serverConn, "server→client", done)
	<-done
	// Closing both sides unblocks the forwarder that is still running.
	closeBoth()
	<-done
	log.Printf("Connection closed: %v", client.RemoteAddr())
}
// inspectAndForward copies bytes read from src to dst until a read or a write
// fails. It logs the reason tagged with label, then sends exactly one value on
// done.
//
// Following the io.Reader contract, any bytes returned together with an error
// (for example a final chunk delivered alongside io.EOF) are written to dst
// before the error is acted on.
func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) {
	// Signal completion on every exit path.
	defer func() { done <- struct{}{} }()

	buf := make([]byte, 4096)
	for {
		n, err := src.Read(buf)
		if n > 0 {
			if _, werr := dst.Write(buf[:n]); werr != nil {
				log.Printf("[%s] write error: %v", label, werr)
				return
			}
		}
		if err != nil {
			log.Printf("[%s] disconnected: %v", label, err)
			return
		}
	}
}