Compare commits
3 Commits
6b87c9e688
...
60ac69621e
| Author | SHA1 | Date | |
|---|---|---|---|
| 60ac69621e | |||
| f18a403545 | |||
| 1795e91387 |
62
auth.go
62
auth.go
@@ -2,15 +2,16 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"net"
|
||||
"x11proxy/x11probe"
|
||||
)
|
||||
|
||||
func resolveDisplay(display string) (ConnType, string) {
|
||||
func resolveDisplay(display string) (x11probe.ConnType, string) {
|
||||
var dispNum int
|
||||
var err error
|
||||
|
||||
@@ -26,25 +27,21 @@ func resolveDisplay(display string) (ConnType, string) {
|
||||
|
||||
unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum))
|
||||
if _, err := os.Stat(unixPath); err == nil {
|
||||
return Unix, unixPath
|
||||
return x11probe.Unix, unixPath
|
||||
}
|
||||
|
||||
tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum)
|
||||
conn, err := net.Dial("tcp", tcpAddr)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return TCP, tcpAddr
|
||||
return x11probe.TCP, tcpAddr
|
||||
}
|
||||
|
||||
return Unix, unixPath
|
||||
return x11probe.Unix, unixPath
|
||||
}
|
||||
|
||||
func getXAuthCookie(display string) ([]byte, error) {
|
||||
out, err := exec.Command("xauth", "list").Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract display number
|
||||
var dispNum string
|
||||
if strings.HasPrefix(display, ":") {
|
||||
dispNum = strings.Split(display[1:], ".")[0]
|
||||
@@ -53,22 +50,56 @@ func getXAuthCookie(display string) ([]byte, error) {
|
||||
dispNum = strings.Split(parts[1], ".")[0]
|
||||
}
|
||||
|
||||
// Query only entries for that display
|
||||
out, err := exec.Command("xauth", "list", ":"+dispNum).Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xauth list failed: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(out), "\n")
|
||||
var bestMatch string
|
||||
var bestCookie string
|
||||
var bestScore int
|
||||
|
||||
for _, line := range lines {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" {
|
||||
if len(fields) < 3 || fields[1] != "MIT-MAGIC-COOKIE-1" {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := fields[0]
|
||||
if strings.Contains(entry, ":"+dispNum) {
|
||||
fmt.Printf("Using XAuth cookie from entry: %s\n", entry)
|
||||
return parseHexCookie(fields[2]), nil
|
||||
cookie := fields[2]
|
||||
|
||||
// Score based on specificity
|
||||
score := 0
|
||||
switch {
|
||||
case strings.HasPrefix(entry, os.Getenv("HOSTNAME")+"/unix:"):
|
||||
score = 100
|
||||
case strings.HasPrefix(entry, "unix:"):
|
||||
score = 90
|
||||
case strings.Contains(entry, "/unix:"):
|
||||
score = 80
|
||||
case strings.HasPrefix(entry, "#ffff#"):
|
||||
score = 70
|
||||
default:
|
||||
score = 50
|
||||
}
|
||||
|
||||
if score > bestScore {
|
||||
bestMatch = entry
|
||||
bestCookie = cookie
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestCookie != "" {
|
||||
fmt.Printf("Using XAuth cookie from entry: %s\n", bestMatch)
|
||||
return parseHexCookie(bestCookie), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no matching cookie found for display %s", display)
|
||||
}
|
||||
|
||||
|
||||
func parseHexCookie(hexStr string) []byte {
|
||||
var cookie []byte
|
||||
for i := 0; i < len(hexStr); i += 2 {
|
||||
@@ -113,4 +144,3 @@ func PatchAuth(data []byte, cookie []byte) []byte {
|
||||
|
||||
return patched
|
||||
}
|
||||
|
||||
|
||||
29
main.go
29
main.go
@@ -6,6 +6,8 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
"x11proxy/x11probe"
|
||||
)
|
||||
|
||||
const version = "0.1"
|
||||
@@ -37,17 +39,36 @@ func main() {
|
||||
display = *overrideDisplay
|
||||
}
|
||||
|
||||
if display == "" {
|
||||
log.Fatalf("No DISPLAY environment variable set and no override provided. Please set DISPLAY or use --display.")
|
||||
}
|
||||
|
||||
connType, target := resolveDisplay(display)
|
||||
fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType))
|
||||
|
||||
err := StartProxy(*overrideSocket, target, connType, display)
|
||||
timeout := 10 * time.Second
|
||||
status, err := x11probe.ProbeX11Socket(connType, target, timeout)
|
||||
if err != nil {
|
||||
log.Fatalf("Proxy error: %v", err)
|
||||
log.Fatalf("Connection probe error: %v", err)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case x11probe.X11Live:
|
||||
log.Println("X11 socket is live and accepting connections")
|
||||
case x11probe.X11AuthRequired:
|
||||
log.Println("X11 socket is live but requires authentication")
|
||||
case x11probe.X11Dead:
|
||||
log.Fatalf("Target connection is dead or unreachable")
|
||||
}
|
||||
|
||||
err2 := StartProxy(*overrideSocket, target, connType, display)
|
||||
if err2 != nil {
|
||||
log.Fatalf("Proxy error: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
func connTypeString(t ConnType) string {
|
||||
if t == Unix {
|
||||
func connTypeString(t x11probe.ConnType) string {
|
||||
if t == x11probe.Unix {
|
||||
return "Unix socket"
|
||||
}
|
||||
return "TCP"
|
||||
|
||||
56
proxy.go
56
proxy.go
@@ -8,16 +8,11 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
"x11proxy/x11probe"
|
||||
)
|
||||
|
||||
type ConnType int
|
||||
|
||||
const (
|
||||
Unix ConnType = iota
|
||||
TCP
|
||||
)
|
||||
|
||||
func StartProxy(proxyPath, target string, connType ConnType, display string) error {
|
||||
func StartProxy(proxyPath, target string, connType x11probe.ConnType, display string) error {
|
||||
os.Remove(proxyPath)
|
||||
listener, err := net.Listen("unix", proxyPath)
|
||||
if err != nil {
|
||||
@@ -36,11 +31,18 @@ func StartProxy(proxyPath, target string, connType ConnType, display string) err
|
||||
go handleConnection(clientConn, connType, target, display)
|
||||
}
|
||||
}
|
||||
|
||||
func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
|
||||
// Step 1: Read fixed-length header
|
||||
// If r is a net.Conn, set a read deadline
|
||||
if conn, ok := r.(net.Conn); ok {
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
conn.SetReadDeadline(deadline)
|
||||
}
|
||||
|
||||
// Step 1: Read fixed-length header manually
|
||||
header := make([]byte, 12)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
n, err := io.ReadFull(r, header)
|
||||
if err != nil {
|
||||
log.Printf("Handshake header read failed after %d bytes: %v", n, err)
|
||||
return nil, fmt.Errorf("failed to read handshake header: %w", err)
|
||||
}
|
||||
|
||||
@@ -58,13 +60,19 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
|
||||
authProtoLen := read16(header[6:8])
|
||||
authDataLen := read16(header[8:10])
|
||||
|
||||
// Step 3: Read remaining fields
|
||||
totalLen := int(authProtoLen+authDataLen)
|
||||
totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen))
|
||||
// Step 3 A: Read remaining fields
|
||||
totalLen := padLen(int(authProtoLen)) + padLen(int(authDataLen))
|
||||
extra := make([]byte, totalLen)
|
||||
if _, err := io.ReadFull(r, extra); err != nil {
|
||||
|
||||
n, err = io.ReadFull(r, extra)
|
||||
if err != nil {
|
||||
log.Printf("Handshake auth read failed after %d bytes: %v", n, err)
|
||||
return nil, fmt.Errorf("failed to read handshake auth fields: %w", err)
|
||||
}
|
||||
// Step 3 B: clear timeout/deadline
|
||||
if conn, ok := r.(net.Conn); ok {
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
// Step 4: Decide whether to patch
|
||||
patch := authProtoLen == 0 || authDataLen == 0
|
||||
@@ -82,7 +90,6 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
|
||||
newAuthProtoLen := uint16(len(authProto))
|
||||
newAuthDataLen := uint16(len(authData))
|
||||
|
||||
// Patch header in-place
|
||||
if isLittle {
|
||||
binary.LittleEndian.PutUint16(header[6:8], newAuthProtoLen)
|
||||
binary.LittleEndian.PutUint16(header[8:10], newAuthDataLen)
|
||||
@@ -113,12 +120,12 @@ func binaryOrder(isLittle bool) binary.ByteOrder {
|
||||
return binary.BigEndian
|
||||
}
|
||||
|
||||
|
||||
func handleConnection(client net.Conn, connType ConnType, target string, display string) {
|
||||
func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) {
|
||||
var serverConn net.Conn
|
||||
var err error
|
||||
var patch bool
|
||||
|
||||
if connType == Unix {
|
||||
if connType == x11probe.Unix {
|
||||
serverConn, err = net.Dial("unix", target)
|
||||
} else {
|
||||
serverConn, err = net.Dial("tcp", target)
|
||||
@@ -130,15 +137,15 @@ func handleConnection(client net.Conn, connType ConnType, target string, display
|
||||
}
|
||||
|
||||
log.Printf("New connection from %v", client.RemoteAddr())
|
||||
patch = true
|
||||
|
||||
cookie, err := getXAuthCookie(display)
|
||||
if err != nil || cookie == nil {
|
||||
log.Printf("Failed to get XAuth cookie")
|
||||
client.Close()
|
||||
serverConn.Close()
|
||||
return
|
||||
patch = false
|
||||
log.Printf("Failed to get XAuth cookie, try forwarding without patching")
|
||||
}
|
||||
|
||||
if patch {
|
||||
log.Printf("About to read the handshake")
|
||||
patched, err := readAndPatchHandshake(client, cookie)
|
||||
if err != nil {
|
||||
@@ -148,7 +155,6 @@ func handleConnection(client net.Conn, connType ConnType, target string, display
|
||||
serverConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, err = serverConn.Write(patched)
|
||||
if err != nil {
|
||||
log.Printf("Initial write failed: %v", err)
|
||||
@@ -156,7 +162,7 @@ func handleConnection(client net.Conn, connType ConnType, target string, display
|
||||
serverConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go inspectAndForward(serverConn, client, "client→server", done)
|
||||
|
||||
125
x11probe/probe.go
Normal file
125
x11probe/probe.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package x11probe
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ConnType int
|
||||
|
||||
const (
|
||||
Unix ConnType = iota
|
||||
TCP
|
||||
)
|
||||
|
||||
// X11Status represents the result of probing an X11 socket.
|
||||
type X11Status int
|
||||
|
||||
const (
|
||||
X11Dead X11Status = iota
|
||||
X11Live
|
||||
X11AuthRequired
|
||||
)
|
||||
|
||||
func ProbeX11Socket(connType ConnType, target string, timeout time.Duration) (X11Status, error) {
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
)
|
||||
|
||||
switch connType {
|
||||
case Unix:
|
||||
conn, err = net.DialTimeout("unix", target, timeout)
|
||||
if err != nil {
|
||||
return X11Dead, fmt.Errorf("unix dial failed: %w", err)
|
||||
}
|
||||
case TCP:
|
||||
addr := parseTCPAddress(target)
|
||||
conn, err = net.DialTimeout("tcp", addr, timeout)
|
||||
if err != nil {
|
||||
return X11Dead, fmt.Errorf("TCP dial failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
log.Printf("Successfully connected to %s", target)
|
||||
|
||||
if err := performHandshake(conn); err != nil {
|
||||
return X11Dead, err
|
||||
}
|
||||
|
||||
status, err := readResponse(conn, timeout)
|
||||
if err != nil {
|
||||
return X11Dead, err
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func parseTCPAddress(target string) string {
|
||||
parts := strings.SplitN(target, ":", 2)
|
||||
|
||||
ipPart := parts[0]
|
||||
portPart := "0"
|
||||
|
||||
if len(parts) == 2 && portPart != "" {
|
||||
portPart = parts[1]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s", ipPart, portPart)
|
||||
}
|
||||
|
||||
func performHandshake(conn net.Conn) error {
|
||||
// Minimal valid handshake with correct padding
|
||||
handshake := []byte{
|
||||
'l', 0, // Byte order
|
||||
0, 11, 0, 0, // Protocol major/minor
|
||||
0, 0, // Auth name length
|
||||
0, 0, // Auth data length
|
||||
0, 0, 0, 0, // Padding
|
||||
}
|
||||
|
||||
if _, err := conn.Write(handshake); err != nil {
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readResponse(conn net.Conn, timeout time.Duration) (X11Status, error) {
|
||||
buf := make([]byte, 1)
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return X11Dead, fmt.Errorf("SetReadDeadline failed: %w", err)
|
||||
}
|
||||
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return X11AuthRequired, nil
|
||||
}
|
||||
return X11Dead, fmt.Errorf("read failed: %w", err)
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return X11Dead, fmt.Errorf("read returned 0 bytes")
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 0x01:
|
||||
return X11Live, nil
|
||||
case 0x02:
|
||||
return X11AuthRequired, nil
|
||||
case 0x00:
|
||||
fmt.Printf("X11 handshake rejected (value %x)\n", buf[0])
|
||||
return X11AuthRequired, nil
|
||||
default:
|
||||
fmt.Printf("unexpected value %x\n", buf[0])
|
||||
return X11Dead, nil
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user