add X11 socket/port probe
This commit is contained in:
175
auth.go
175
auth.go
@@ -1,116 +1,115 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"x11proxy/x11probe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func resolveDisplay(display string) (ConnType, string) {
|
func resolveDisplay(display string) (x11probe.ConnType, string) {
|
||||||
var dispNum int
|
var dispNum int
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if strings.HasPrefix(display, ":") {
|
if strings.HasPrefix(display, ":") {
|
||||||
dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0])
|
dispNum, err = strconv.Atoi(strings.Split(display[1:], ".")[0])
|
||||||
} else {
|
} else {
|
||||||
parts := strings.Split(display, ":")
|
parts := strings.Split(display, ":")
|
||||||
dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0])
|
dispNum, err = strconv.Atoi(strings.Split(parts[1], ".")[0])
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dispNum = 0
|
dispNum = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum))
|
unixPath := path.Join("/tmp/.X11-unix", fmt.Sprintf("X%d", dispNum))
|
||||||
if _, err := os.Stat(unixPath); err == nil {
|
if _, err := os.Stat(unixPath); err == nil {
|
||||||
return Unix, unixPath
|
return x11probe.Unix, unixPath
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum)
|
tcpAddr := fmt.Sprintf("127.0.0.1:%d", 6000+dispNum)
|
||||||
conn, err := net.Dial("tcp", tcpAddr)
|
conn, err := net.Dial("tcp", tcpAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return TCP, tcpAddr
|
return x11probe.TCP, tcpAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
return Unix, unixPath
|
return x11probe.Unix, unixPath
|
||||||
}
|
}
|
||||||
|
|
||||||
func getXAuthCookie(display string) ([]byte, error) {
|
func getXAuthCookie(display string) ([]byte, error) {
|
||||||
out, err := exec.Command("xauth", "list").Output()
|
out, err := exec.Command("xauth", "list").Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var dispNum string
|
var dispNum string
|
||||||
if strings.HasPrefix(display, ":") {
|
if strings.HasPrefix(display, ":") {
|
||||||
dispNum = strings.Split(display[1:], ".")[0]
|
dispNum = strings.Split(display[1:], ".")[0]
|
||||||
} else {
|
} else {
|
||||||
parts := strings.Split(display, ":")
|
parts := strings.Split(display, ":")
|
||||||
dispNum = strings.Split(parts[1], ".")[0]
|
dispNum = strings.Split(parts[1], ".")[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(string(out), "\n")
|
lines := strings.Split(string(out), "\n")
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" {
|
if len(fields) >= 3 && fields[1] == "MIT-MAGIC-COOKIE-1" {
|
||||||
entry := fields[0]
|
entry := fields[0]
|
||||||
if strings.Contains(entry, ":"+dispNum) {
|
if strings.Contains(entry, ":"+dispNum) {
|
||||||
fmt.Printf("Using XAuth cookie from entry: %s\n", entry)
|
fmt.Printf("Using XAuth cookie from entry: %s\n", entry)
|
||||||
return parseHexCookie(fields[2]), nil
|
return parseHexCookie(fields[2]), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("no matching cookie found for display %s", display)
|
return nil, fmt.Errorf("no matching cookie found for display %s", display)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func parseHexCookie(hexStr string) []byte {
|
func parseHexCookie(hexStr string) []byte {
|
||||||
var cookie []byte
|
var cookie []byte
|
||||||
for i := 0; i < len(hexStr); i += 2 {
|
for i := 0; i < len(hexStr); i += 2 {
|
||||||
var val byte
|
var val byte
|
||||||
fmt.Sscanf(hexStr[i:i+2], "%02x", &val)
|
fmt.Sscanf(hexStr[i:i+2], "%02x", &val)
|
||||||
cookie = append(cookie, val)
|
cookie = append(cookie, val)
|
||||||
}
|
}
|
||||||
return cookie
|
return cookie
|
||||||
}
|
}
|
||||||
|
|
||||||
func PatchAuth(data []byte, cookie []byte) []byte {
|
func PatchAuth(data []byte, cookie []byte) []byte {
|
||||||
isLittleEndian := data[0] == 'l'
|
isLittleEndian := data[0] == 'l'
|
||||||
|
|
||||||
var authProtoLen, authDataLen int
|
var authProtoLen, authDataLen int
|
||||||
if isLittleEndian {
|
if isLittleEndian {
|
||||||
authProtoLen = int(data[7])<<8 | int(data[6])
|
authProtoLen = int(data[7])<<8 | int(data[6])
|
||||||
authDataLen = int(data[9])<<8 | int(data[8])
|
authDataLen = int(data[9])<<8 | int(data[8])
|
||||||
} else {
|
} else {
|
||||||
authProtoLen = int(data[6])<<8 | int(data[7])
|
authProtoLen = int(data[6])<<8 | int(data[7])
|
||||||
authDataLen = int(data[8])<<8 | int(data[9])
|
authDataLen = int(data[8])<<8 | int(data[9])
|
||||||
}
|
}
|
||||||
|
|
||||||
headerLen := 12
|
headerLen := 12
|
||||||
authProtoPad := (authProtoLen + 3) & ^3
|
authProtoPad := (authProtoLen + 3) & ^3
|
||||||
authDataPad := (authDataLen + 3) & ^3
|
authDataPad := (authDataLen + 3) & ^3
|
||||||
authDataStart := headerLen + authProtoPad
|
authDataStart := headerLen + authProtoPad
|
||||||
|
|
||||||
// Replace cookie and update length
|
// Replace cookie and update length
|
||||||
patched := make([]byte, headerLen+authProtoPad+authDataPad)
|
patched := make([]byte, headerLen+authProtoPad+authDataPad)
|
||||||
copy(patched, data[:headerLen+authProtoPad])
|
copy(patched, data[:headerLen+authProtoPad])
|
||||||
copy(patched[authDataStart:], cookie)
|
copy(patched[authDataStart:], cookie)
|
||||||
|
|
||||||
// Update authDataLen to match cookie length
|
// Update authDataLen to match cookie length
|
||||||
cookieLen := len(cookie)
|
cookieLen := len(cookie)
|
||||||
if isLittleEndian {
|
if isLittleEndian {
|
||||||
patched[8] = byte(cookieLen)
|
patched[8] = byte(cookieLen)
|
||||||
patched[9] = byte(cookieLen >> 8)
|
patched[9] = byte(cookieLen >> 8)
|
||||||
} else {
|
} else {
|
||||||
patched[8] = byte(cookieLen >> 8)
|
patched[8] = byte(cookieLen >> 8)
|
||||||
patched[9] = byte(cookieLen)
|
patched[9] = byte(cookieLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
return patched
|
return patched
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
144
main.go
144
main.go
@@ -1,87 +1,113 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
"x11proxy/x11probe"
|
||||||
)
|
)
|
||||||
|
|
||||||
const version = "0.1"
|
const version = "0.1"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
overrideDisplay = flag.String("display", "", "Override DISPLAY")
|
overrideDisplay = flag.String("display", "", "Override DISPLAY")
|
||||||
overrideSocket = flag.String("proxy-socket", "/tmp/.X11-unix/X5", "Proxy socket path")
|
overrideSocket = flag.String("proxy-socket", "/tmp/.X11-unix/X5", "Proxy socket path")
|
||||||
showVersion = flag.Bool("version", false, "Show version and exit")
|
showVersion = flag.Bool("version", false, "Show version and exit")
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "Usage of %s:\n\n", os.Args[0])
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *showVersion {
|
if *showVersion {
|
||||||
fmt.Printf("X11Proxy version %s\n", version)
|
fmt.Printf("X11Proxy version %s\n", version)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Printf("version %s\n", version)
|
fmt.Printf("version %s\n", version)
|
||||||
|
|
||||||
display := os.Getenv("DISPLAY")
|
display := os.Getenv("DISPLAY")
|
||||||
if *overrideDisplay != "" {
|
if *overrideDisplay != "" {
|
||||||
display = *overrideDisplay
|
display = *overrideDisplay
|
||||||
}
|
}
|
||||||
|
|
||||||
connType, target := resolveDisplay(display)
|
connType, target := resolveDisplay(display)
|
||||||
fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType))
|
fmt.Printf("Proxying to %s (%s)\n", target, connTypeString(connType))
|
||||||
|
|
||||||
err := StartProxy(*overrideSocket, target, connType, display)
|
timeout := 10 * time.Second
|
||||||
if err != nil {
|
status, err := x11probe.ProbeX11Socket(connType, target, timeout)
|
||||||
log.Fatalf("Proxy error: %v", err)
|
if err != nil {
|
||||||
}
|
log.Fatalf("Connection probe error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case x11probe.X11Dead:
|
||||||
|
log.Fatalf("Target connection is dead or unreachable")
|
||||||
|
case x11probe.X11AuthRequired:
|
||||||
|
fmt.Println("Authentication required for the target connection.")
|
||||||
|
case x11probe.X11Live:
|
||||||
|
fmt.Println("Connection to target is live and valid.")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case x11probe.X11Live:
|
||||||
|
log.Println("X11 socket is live and accepting connections")
|
||||||
|
case x11probe.X11AuthRequired:
|
||||||
|
log.Println("X11 socket is live but requires authentication")
|
||||||
|
case x11probe.X11Dead:
|
||||||
|
log.Fatalf("Target connection is dead or unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
err2 := StartProxy(*overrideSocket, target, connType, display)
|
||||||
|
if err2 != nil {
|
||||||
|
log.Fatalf("Proxy error: %v", err2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func connTypeString(t ConnType) string {
|
func connTypeString(t x11probe.ConnType) string {
|
||||||
if t == Unix {
|
if t == x11probe.Unix {
|
||||||
return "Unix socket"
|
return "Unix socket"
|
||||||
}
|
}
|
||||||
return "TCP"
|
return "TCP"
|
||||||
}
|
}
|
||||||
|
|
||||||
func hexDump(buf []byte) string {
|
func hexDump(buf []byte) string {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
for i := 0; i < len(buf); i += 16 {
|
for i := 0; i < len(buf); i += 16 {
|
||||||
line := fmt.Sprintf("%08x ", i)
|
line := fmt.Sprintf("%08x ", i)
|
||||||
|
|
||||||
for j := 0; j < 16; j++ {
|
for j := 0; j < 16; j++ {
|
||||||
if i+j < len(buf) {
|
if i+j < len(buf) {
|
||||||
line += fmt.Sprintf("%02x ", buf[i+j])
|
line += fmt.Sprintf("%02x ", buf[i+j])
|
||||||
} else {
|
} else {
|
||||||
line += " "
|
line += " "
|
||||||
}
|
}
|
||||||
if j == 7 {
|
if j == 7 {
|
||||||
line += " "
|
line += " "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
line += " |"
|
line += " |"
|
||||||
|
|
||||||
for j := 0; j < 16 && i+j < len(buf); j++ {
|
for j := 0; j < 16 && i+j < len(buf); j++ {
|
||||||
b := buf[i+j]
|
b := buf[i+j]
|
||||||
if b >= 32 && b <= 126 {
|
if b >= 32 && b <= 126 {
|
||||||
line += string(b)
|
line += string(b)
|
||||||
} else {
|
} else {
|
||||||
line += "."
|
line += "."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
line += "|\n"
|
line += "|\n"
|
||||||
out.WriteString(line)
|
out.WriteString(line)
|
||||||
}
|
}
|
||||||
return out.String()
|
return out.String()
|
||||||
}
|
}
|
||||||
|
|||||||
257
proxy.go
257
proxy.go
@@ -1,86 +1,80 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"x11proxy/x11probe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnType int
|
func StartProxy(proxyPath, target string, connType x11probe.ConnType, display string) error {
|
||||||
|
os.Remove(proxyPath)
|
||||||
|
listener, err := net.Listen("unix", proxyPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
os.Chmod(proxyPath, 0700)
|
||||||
|
|
||||||
const (
|
log.Printf("Proxy listening on %s", proxyPath)
|
||||||
Unix ConnType = iota
|
|
||||||
TCP
|
|
||||||
)
|
|
||||||
|
|
||||||
func StartProxy(proxyPath, target string, connType ConnType, display string) error {
|
for {
|
||||||
os.Remove(proxyPath)
|
clientConn, err := listener.Accept()
|
||||||
listener, err := net.Listen("unix", proxyPath)
|
if err != nil {
|
||||||
if err != nil {
|
log.Printf("Accept error: %v", err)
|
||||||
return err
|
continue
|
||||||
}
|
}
|
||||||
os.Chmod(proxyPath, 0700)
|
go handleConnection(clientConn, connType, target, display)
|
||||||
|
}
|
||||||
log.Printf("Proxy listening on %s", proxyPath)
|
|
||||||
|
|
||||||
for {
|
|
||||||
clientConn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Accept error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go handleConnection(clientConn, connType, target, display)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
|
func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
|
||||||
// Step 1: Read fixed-length header
|
// Step 1: Read fixed-length header
|
||||||
header := make([]byte, 12)
|
header := make([]byte, 12)
|
||||||
if _, err := io.ReadFull(r, header); err != nil {
|
if _, err := io.ReadFull(r, header); err != nil {
|
||||||
return nil, fmt.Errorf("failed to read handshake header: %w", err)
|
return nil, fmt.Errorf("failed to read handshake header: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
byteOrder := header[0]
|
byteOrder := header[0]
|
||||||
isLittle := byteOrder == 'l'
|
isLittle := byteOrder == 'l'
|
||||||
|
|
||||||
// Step 2: Parse lengths
|
// Step 2: Parse lengths
|
||||||
read16 := func(b []byte) uint16 {
|
read16 := func(b []byte) uint16 {
|
||||||
if isLittle {
|
if isLittle {
|
||||||
return binary.LittleEndian.Uint16(b)
|
return binary.LittleEndian.Uint16(b)
|
||||||
}
|
}
|
||||||
return binary.BigEndian.Uint16(b)
|
return binary.BigEndian.Uint16(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
authProtoLen := read16(header[6:8])
|
authProtoLen := read16(header[6:8])
|
||||||
authDataLen := read16(header[8:10])
|
authDataLen := read16(header[8:10])
|
||||||
|
|
||||||
// Step 3: Read remaining fields
|
// Step 3: Read remaining fields
|
||||||
totalLen := int(authProtoLen+authDataLen)
|
totalLen := int(authProtoLen + authDataLen)
|
||||||
totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen))
|
totalLen += padLen(int(authProtoLen)) + padLen(int(authDataLen))
|
||||||
extra := make([]byte, totalLen)
|
extra := make([]byte, totalLen)
|
||||||
if _, err := io.ReadFull(r, extra); err != nil {
|
if _, err := io.ReadFull(r, extra); err != nil {
|
||||||
return nil, fmt.Errorf("failed to read handshake auth fields: %w", err)
|
return nil, fmt.Errorf("failed to read handshake auth fields: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4: Decide whether to patch
|
// Step 4: Decide whether to patch
|
||||||
patch := authProtoLen == 0 || authDataLen == 0
|
patch := authProtoLen == 0 || authDataLen == 0
|
||||||
|
|
||||||
var authProto, authData []byte
|
var authProto, authData []byte
|
||||||
if patch {
|
if patch {
|
||||||
authProto = []byte("MIT-MAGIC-COOKIE-1")
|
authProto = []byte("MIT-MAGIC-COOKIE-1")
|
||||||
authData = cookie
|
authData = cookie
|
||||||
} else {
|
} else {
|
||||||
authProto = extra[:authProtoLen]
|
authProto = extra[:authProtoLen]
|
||||||
authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)]
|
authData = extra[padLen(int(authProtoLen)) : padLen(int(authProtoLen))+int(authDataLen)]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 5: Rebuild handshake
|
// Step 5: Rebuild handshake
|
||||||
newAuthProtoLen := uint16(len(authProto))
|
newAuthProtoLen := uint16(len(authProto))
|
||||||
newAuthDataLen := uint16(len(authData))
|
newAuthDataLen := uint16(len(authData))
|
||||||
|
|
||||||
// Patch header in-place
|
// Patch header in-place
|
||||||
if isLittle {
|
if isLittle {
|
||||||
@@ -91,53 +85,52 @@ func readAndPatchHandshake(r io.Reader, cookie []byte) ([]byte, error) {
|
|||||||
binary.BigEndian.PutUint16(header[8:10], newAuthDataLen)
|
binary.BigEndian.PutUint16(header[8:10], newAuthDataLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := bytes.NewBuffer(header)
|
buf := bytes.NewBuffer(header)
|
||||||
buf.Write(authProto)
|
buf.Write(authProto)
|
||||||
buf.Write(make([]byte, padLen(len(authProto))-len(authProto)))
|
buf.Write(make([]byte, padLen(len(authProto))-len(authProto)))
|
||||||
buf.Write(authData)
|
buf.Write(authData)
|
||||||
buf.Write(make([]byte, padLen(len(authData))-len(authData)))
|
buf.Write(make([]byte, padLen(len(authData))-len(authData)))
|
||||||
|
|
||||||
log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s",
|
log.Printf("Rebuilt handshake: authProtoLen=%d, authDataLen=%d\n%s",
|
||||||
newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes()))
|
newAuthProtoLen, newAuthDataLen, hexDump(buf.Bytes()))
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func padLen(n int) int {
|
func padLen(n int) int {
|
||||||
return (n + 3) &^ 3 // round up to next multiple of 4
|
return (n + 3) &^ 3 // round up to next multiple of 4
|
||||||
}
|
}
|
||||||
|
|
||||||
func binaryOrder(isLittle bool) binary.ByteOrder {
|
func binaryOrder(isLittle bool) binary.ByteOrder {
|
||||||
if isLittle {
|
if isLittle {
|
||||||
return binary.LittleEndian
|
return binary.LittleEndian
|
||||||
}
|
}
|
||||||
return binary.BigEndian
|
return binary.BigEndian
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleConnection(client net.Conn, connType x11probe.ConnType, target string, display string) {
|
||||||
|
var serverConn net.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
func handleConnection(client net.Conn, connType ConnType, target string, display string) {
|
if connType == x11probe.Unix {
|
||||||
var serverConn net.Conn
|
serverConn, err = net.Dial("unix", target)
|
||||||
var err error
|
} else {
|
||||||
|
serverConn, err = net.Dial("tcp", target)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to connect to target: %v", err)
|
||||||
|
client.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if connType == Unix {
|
log.Printf("New connection from %v", client.RemoteAddr())
|
||||||
serverConn, err = net.Dial("unix", target)
|
|
||||||
} else {
|
|
||||||
serverConn, err = net.Dial("tcp", target)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to connect to target: %v", err)
|
|
||||||
client.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("New connection from %v", client.RemoteAddr())
|
cookie, err := getXAuthCookie(display)
|
||||||
|
if err != nil || cookie == nil {
|
||||||
cookie, err := getXAuthCookie(display)
|
log.Printf("Failed to get XAuth cookie")
|
||||||
if err != nil || cookie == nil {
|
client.Close()
|
||||||
log.Printf("Failed to get XAuth cookie")
|
serverConn.Close()
|
||||||
client.Close()
|
return
|
||||||
serverConn.Close()
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("About to read the handshake")
|
log.Printf("About to read the handshake")
|
||||||
patched, err := readAndPatchHandshake(client, cookie)
|
patched, err := readAndPatchHandshake(client, cookie)
|
||||||
@@ -148,44 +141,44 @@ func handleConnection(client net.Conn, connType ConnType, target string, display
|
|||||||
serverConn.Close()
|
serverConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = serverConn.Write(patched)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Initial write failed: %v", err)
|
|
||||||
client.Close()
|
|
||||||
serverConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{}, 2)
|
_, err = serverConn.Write(patched)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Initial write failed: %v", err)
|
||||||
|
client.Close()
|
||||||
|
serverConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
go inspectAndForward(serverConn, client, "client→server", done)
|
done := make(chan struct{}, 2)
|
||||||
go inspectAndForward(client, serverConn, "server→client", done)
|
|
||||||
|
|
||||||
<-done
|
go inspectAndForward(serverConn, client, "client→server", done)
|
||||||
client.Close()
|
go inspectAndForward(client, serverConn, "server→client", done)
|
||||||
serverConn.Close()
|
|
||||||
<-done
|
|
||||||
|
|
||||||
log.Printf("Connection closed: %v", client.RemoteAddr())
|
<-done
|
||||||
|
client.Close()
|
||||||
|
serverConn.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
log.Printf("Connection closed: %v", client.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) {
|
func inspectAndForward(dst net.Conn, src net.Conn, label string, done chan<- struct{}) {
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
n, err := src.Read(buf)
|
n, err := src.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] disconnected: %v", label, err)
|
log.Printf("[%s] disconnected: %v", label, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
//log.Printf("[%s] forwarded %d bytes", label, n)
|
//log.Printf("[%s] forwarded %d bytes", label, n)
|
||||||
_, err := dst.Write(buf[:n])
|
_, err := dst.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] write error: %v", label, err)
|
log.Printf("[%s] write error: %v", label, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|||||||
125
x11probe/probe.go
Normal file
125
x11probe/probe.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package x11probe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Unix ConnType = iota
|
||||||
|
TCP
|
||||||
|
)
|
||||||
|
|
||||||
|
// X11Status represents the result of probing an X11 socket.
|
||||||
|
type X11Status int
|
||||||
|
|
||||||
|
const (
|
||||||
|
X11Dead X11Status = iota
|
||||||
|
X11Live
|
||||||
|
X11AuthRequired
|
||||||
|
)
|
||||||
|
|
||||||
|
func ProbeX11Socket(connType ConnType, target string, timeout time.Duration) (X11Status, error) {
|
||||||
|
var (
|
||||||
|
conn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
switch connType {
|
||||||
|
case Unix:
|
||||||
|
conn, err = net.DialTimeout("unix", target, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return X11Dead, fmt.Errorf("unix dial failed: %w", err)
|
||||||
|
}
|
||||||
|
case TCP:
|
||||||
|
addr := parseTCPAddress(target)
|
||||||
|
conn, err = net.DialTimeout("tcp", addr, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return X11Dead, fmt.Errorf("TCP dial failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer conn.Close()
|
||||||
|
log.Printf("Successfully connected to %s", target)
|
||||||
|
|
||||||
|
if err := performHandshake(conn); err != nil {
|
||||||
|
return X11Dead, err
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := readResponse(conn, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return X11Dead, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTCPAddress(target string) string {
|
||||||
|
parts := strings.SplitN(target, ":", 2)
|
||||||
|
|
||||||
|
ipPart := parts[0]
|
||||||
|
portPart := "0"
|
||||||
|
|
||||||
|
if len(parts) == 2 && portPart != "" {
|
||||||
|
portPart = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s:%s", ipPart, portPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func performHandshake(conn net.Conn) error {
|
||||||
|
// Minimal valid handshake with correct padding
|
||||||
|
handshake := []byte{
|
||||||
|
'l', 0, // Byte order
|
||||||
|
0, 11, 0, 0, // Protocol major/minor
|
||||||
|
0, 0, // Auth name length
|
||||||
|
0, 0, // Auth data length
|
||||||
|
0, 0, 0, 0, // Padding
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(handshake); err != nil {
|
||||||
|
return fmt.Errorf("write failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readResponse(conn net.Conn, timeout time.Duration) (X11Status, error) {
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||||
|
return X11Dead, fmt.Errorf("SetReadDeadline failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return X11AuthRequired, nil
|
||||||
|
}
|
||||||
|
return X11Dead, fmt.Errorf("read failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return X11Dead, fmt.Errorf("read returned 0 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch buf[0] {
|
||||||
|
case 0x01:
|
||||||
|
return X11Live, nil
|
||||||
|
case 0x02:
|
||||||
|
return X11AuthRequired, nil
|
||||||
|
case 0x00:
|
||||||
|
fmt.Printf("X11 handshake rejected (value %x)\n", buf[0])
|
||||||
|
return X11AuthRequired, nil
|
||||||
|
default:
|
||||||
|
fmt.Printf("unexpected value %x\n", buf[0])
|
||||||
|
return X11Dead, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user