package main import ( "bytes" "encoding/binary" "fmt" "io" "log" "net" "os" "time" "x11proxy/x11probe" ) func StartProxy(proxyPath, target string, connType x11probe.ConnType, display string) error { os.Remove(proxyPath) listener, err := net.Listen("unix", proxyPath) if err != nil { return err } os.Chmod(proxyPath, 0700) log.Printf("Proxy listening on %s", proxyPath) for { clientConn, err := listener.Accept() if err != nil { log.Printf("Accept error: %v", err) continue } go handleConnection(clientConn, connType, target, display) } } func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) { // If r is a net.Conn, set a read deadline if conn, ok := r.(net.Conn); ok { deadline := time.Now().Add(3 * time.Second) conn.SetReadDeadline(deadline) } // Step 1: Read fixed-length header manually header := make([]byte, 12) n, err := io.ReadFull(r, header) if err != nil { log.Printf("Handshake header read failed after %d bytes: %v", n, err) return nil, fmt.Errorf("failed to read handshake header: %w", err) } byteOrder := header[0] isLittle := byteOrder == 'l' // Step 2: Parse lengths read16 := func(b []byte) uint16 { if isLittle { return binary.LittleEndian.Uint16(b) } return binary.BigEndian.Uint16(b) } authProtoLen := read16(header[6:8]) authDataLen := read16(header[8:10]) // Step 3 A: Read remaining fields totalLen := padLen(int(authProtoLen)) + padLen(int(authDataLen)) extra := make([]byte, totalLen) n, err = io.ReadFull(r, extra) if err != nil { log.Printf("Handshake auth read failed after %d bytes: %v", n, err) return nil, fmt.Errorf("failed to read handshake auth fields: %w", err) } // Step 3 B: clear timeout/deadline if conn, ok := r.(net.Conn); ok { conn.SetReadDeadline(time.Time{}) } // Step 4: Decide whether to patch patch := authProtoLen == 0 || authDataLen == 0 var authProto, authData []byte if patch { authProto = []byte("MIT-MAGIC-COOKIE-1") authData = cookie } else { authProto = extra[:authProtoLen] authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)] } // Step 5: Rebuild handshake newAuthProtoLen := uint16(len(authProto)) newAuthDataLen := uint16(len(authData)) if isLittle { binary.LittleEndian.PutUint16(header[6:8], newAuthProtoLen) binary.LittleEndian.PutUint16(header[8:10], newAuthDataLen) } else { binary.BigEndian.PutUint16(header[6:8], newAuthProtoLen) binary.BigEndian.PutUint16(header[8:10], newAuthDataLen) } buf := bytes.NewBuffer(header) buf.Write(authProto) buf.Write(make([]byte, padLen(len(authProto))-len(authProto))) buf.Write(authData) buf.Write(make([]byte, padLen(len(authData))-len(authData))) log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s", newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes())) return buf.Bytes(), nil } func padLen(n int) int { return (n + 3) &^ 3 // round up to next multiple of 4 } func binaryOrder(isLittle bool) binary.ByteOrder { if isLittle { return binary.LittleEndian } return binary.BigEndian } func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) { var serverConn net.Conn var err error var patch bool if connType == x11probe.Unix { serverConn, err = net.Dial("unix", target) } else { serverConn, err = net.Dial("tcp", target) } if err != nil { log.Printf("Failed to connect to target: %v", err) client.Close() return } log.Printf("New connection from %v", client.RemoteAddr()) patch = true cookie, err := getXAuthCookie(display) if err != nil || cookie == nil { patch = false log.Printf("Failed to get XAuth cookie, try forwarding without patching") } if patch { log.Printf("About to read the handshake") patched, err := readAndPatchHandshake(client, cookie) if err != nil { log.Printf("Initial read failed: %v", err) log.Fatal(err) client.Close() serverConn.Close() return } _, err = serverConn.Write(patched) if err != nil { log.Printf("Initial write failed: %v", err) client.Close() serverConn.Close() return } } done := make(chan struct{}, 2) go inspectAndForward(serverConn, client, "client→server", done) go inspectAndForward(client, serverConn, "server→client", done) <-done client.Close() serverConn.Close() <-done log.Printf("Connection closed: %v", client.RemoteAddr()) } func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) { buf := make([]byte, 4096) for { n, err := src.Read(buf) if err != nil { log.Printf("[%s] disconnected: %v", label, err) break } if n > 0 { //log.Printf("[%s] forwarded %d bytes", label, n) _, err := dst.Write(buf[:n]) if err != nil { log.Printf("[%s] write error: %v", label, err) break } } } done <- struct{}{} }