mirror of https://github.com/go-gost/gost.git
enhancement for supporting local ssh config, so you can use alias in ssh://<alias>
parent
340ba32ef0
commit
bb043572d3
|
|
@ -36,7 +36,7 @@ func (p *program) Init(env svc.Environment) error {
|
|||
parser.Init(parser.Args{
|
||||
CfgFile: cfgFile,
|
||||
Services: services,
|
||||
Nodes: nodes,
|
||||
Nodes: expandSSHNodes(nodes),
|
||||
Debug: debug,
|
||||
Trace: trace,
|
||||
ApiAddr: apiAddr,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,193 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// sshHostEntry holds the resolved settings for a single SSH host alias.
|
||||
type sshHostEntry struct {
|
||||
hostname string // HostName directive
|
||||
port string // Port directive
|
||||
user string // User directive
|
||||
identityFile string // first IdentityFile directive
|
||||
}
|
||||
|
||||
// parseSSHConfigLine splits an SSH config line into (key, value).
|
||||
// Handles both "Key Value" and "Key=Value" (with optional spaces around =).
|
||||
func parseSSHConfigLine(line string) (key, value string, ok bool) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
return "", "", false
|
||||
}
|
||||
// Find first whitespace or '='
|
||||
i := 0
|
||||
for i < len(line) && line[i] != ' ' && line[i] != '\t' && line[i] != '=' {
|
||||
i++
|
||||
}
|
||||
if i == len(line) {
|
||||
return "", "", false
|
||||
}
|
||||
key = line[:i]
|
||||
rest := strings.TrimLeft(line[i:], " \t=")
|
||||
// Strip trailing inline comment (must be preceded by whitespace)
|
||||
if idx := strings.Index(rest, " #"); idx >= 0 {
|
||||
rest = strings.TrimSpace(rest[:idx])
|
||||
}
|
||||
value = strings.Trim(rest, `"'`)
|
||||
return key, value, value != ""
|
||||
}
|
||||
|
||||
// readSSHConfig parses ~/.ssh/config and returns a map of host alias/pattern
|
||||
// to resolved entry. Only exact-match patterns (no wildcards) are indexed so
|
||||
// they can be looked up directly; wildcard patterns are skipped for now.
|
||||
func readSSHConfig() map[string]*sshHostEntry {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
path := filepath.Join(home, ".ssh", "config")
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
entries := make(map[string]*sshHostEntry)
|
||||
var current *sshHostEntry
|
||||
var currentPatterns []string
|
||||
|
||||
commit := func() {
|
||||
if current == nil {
|
||||
return
|
||||
}
|
||||
for _, pat := range currentPatterns {
|
||||
// Skip wildcard patterns — they can't be used for exact lookup.
|
||||
if strings.ContainsAny(pat, "*?") {
|
||||
continue
|
||||
}
|
||||
if _, exists := entries[pat]; !exists {
|
||||
entries[pat] = current
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
key, value, ok := parseSSHConfigLine(scanner.Text())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(key) {
|
||||
case "host":
|
||||
commit()
|
||||
currentPatterns = strings.Fields(value)
|
||||
current = &sshHostEntry{}
|
||||
case "hostname":
|
||||
if current != nil && current.hostname == "" {
|
||||
current.hostname = value
|
||||
}
|
||||
case "port":
|
||||
if current != nil && current.port == "" {
|
||||
current.port = value
|
||||
}
|
||||
case "user":
|
||||
if current != nil && current.user == "" {
|
||||
current.user = value
|
||||
}
|
||||
case "identityfile":
|
||||
if current != nil && current.identityFile == "" {
|
||||
if strings.HasPrefix(value, "~/") {
|
||||
home, _ := os.UserHomeDir()
|
||||
value = filepath.Join(home, value[2:])
|
||||
}
|
||||
current.identityFile = value
|
||||
}
|
||||
}
|
||||
}
|
||||
commit()
|
||||
|
||||
return entries
|
||||
}
|
||||
|
||||
// expandSSHNode rewrites a single node URL string using ~/.ssh/config when the
|
||||
// scheme is "ssh" and the host matches a Host entry. Fields already present in
|
||||
// the URL are never overridden.
|
||||
func expandSSHNode(raw string, cfg map[string]*sshHostEntry) string {
|
||||
if cfg == nil || !strings.HasPrefix(raw, "ssh://") {
|
||||
return raw
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil || u.Scheme != "ssh" {
|
||||
return raw
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
port := u.Port()
|
||||
|
||||
entry, ok := cfg[host]
|
||||
if !ok {
|
||||
return raw
|
||||
}
|
||||
|
||||
// Resolve hostname.
|
||||
newHost := host
|
||||
if entry.hostname != "" {
|
||||
newHost = entry.hostname
|
||||
}
|
||||
|
||||
// Resolve port: only substitute when the URL carries no explicit port or
|
||||
// carries the SSH default (22) and the config specifies a different one.
|
||||
newPort := port
|
||||
if newPort == "" || newPort == "22" {
|
||||
if entry.port != "" && entry.port != "22" {
|
||||
newPort = entry.port
|
||||
}
|
||||
}
|
||||
|
||||
if newPort != "" && newPort != "22" {
|
||||
u.Host = net.JoinHostPort(newHost, newPort)
|
||||
} else {
|
||||
u.Host = newHost
|
||||
}
|
||||
|
||||
// Apply user if not already present in the URL.
|
||||
if entry.user != "" && (u.User == nil || u.User.Username() == "") {
|
||||
var pw string
|
||||
if u.User != nil {
|
||||
pw, _ = u.User.Password()
|
||||
}
|
||||
if pw != "" {
|
||||
u.User = url.UserPassword(entry.user, pw)
|
||||
} else {
|
||||
u.User = url.User(entry.user)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply identity file as the "key" query parameter when not already set.
|
||||
if entry.identityFile != "" {
|
||||
q := u.Query()
|
||||
if q.Get("privateKeyFile") == "" {
|
||||
q.Set("privateKeyFile", entry.identityFile)
|
||||
u.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// expandSSHNodes applies expandSSHNode to every element of the node list and
|
||||
// returns the result as a new slice. The original slice is not modified.
|
||||
func expandSSHNodes(nodes []string) []string {
|
||||
cfg := readSSHConfig()
|
||||
result := make([]string, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = expandSSHNode(n, cfg)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeSSHConfig(t *testing.T, content string) (cleanup func()) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
sshDir := filepath.Join(dir, ".ssh")
|
||||
if err := os.Mkdir(sshDir, 0700); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sshDir, "config"), []byte(content), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
old := os.Getenv("HOME")
|
||||
os.Setenv("HOME", dir)
|
||||
return func() { os.Setenv("HOME", old) }
|
||||
}
|
||||
|
||||
func TestParseSSHConfigLine(t *testing.T) {
|
||||
cases := []struct {
|
||||
line, wantKey, wantVal string
|
||||
wantOk bool
|
||||
}{
|
||||
{" HostName example.com", "HostName", "example.com", true},
|
||||
{"HostName=example.com", "HostName", "example.com", true},
|
||||
{"HostName = example.com", "HostName", "example.com", true},
|
||||
{"# comment", "", "", false},
|
||||
{"", "", "", false},
|
||||
{`IdentityFile "~/.ssh/id_rsa"`, "IdentityFile", "~/.ssh/id_rsa", true},
|
||||
{"Host myalias", "Host", "myalias", true},
|
||||
{"Host alias1 alias2", "Host", "alias1 alias2", true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
k, v, ok := parseSSHConfigLine(c.line)
|
||||
if ok != c.wantOk || k != c.wantKey || v != c.wantVal {
|
||||
t.Errorf("parseSSHConfigLine(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
c.line, k, v, ok, c.wantKey, c.wantVal, c.wantOk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandSSHNode_NoMatch(t *testing.T) {
|
||||
cfg := map[string]*sshHostEntry{
|
||||
"myalias": {hostname: "real.host.com", port: "2222", user: "alice"},
|
||||
}
|
||||
if got := expandSSHNode("http://myalias:8080", cfg); got != "http://myalias:8080" {
|
||||
t.Errorf("non-SSH URL was modified: %q", got)
|
||||
}
|
||||
if got := expandSSHNode("ssh://other:22", cfg); got != "ssh://other:22" {
|
||||
t.Errorf("unknown SSH host was modified: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandSSHNode_FullExpansion(t *testing.T) {
|
||||
keyPath := "/home/alice/.ssh/id_ed25519"
|
||||
cfg := map[string]*sshHostEntry{
|
||||
"myalias": {
|
||||
hostname: "real.host.com",
|
||||
port: "2222",
|
||||
user: "alice",
|
||||
identityFile: keyPath,
|
||||
},
|
||||
}
|
||||
|
||||
got := expandSSHNode("ssh://myalias", cfg)
|
||||
u, err := url.Parse(got)
|
||||
if err != nil {
|
||||
t.Fatalf("parse result URL: %v", err)
|
||||
}
|
||||
if u.Hostname() != "real.host.com" {
|
||||
t.Errorf("hostname: got %q, want %q", u.Hostname(), "real.host.com")
|
||||
}
|
||||
if u.Port() != "2222" {
|
||||
t.Errorf("port: got %q, want %q", u.Port(), "2222")
|
||||
}
|
||||
if u.User.Username() != "alice" {
|
||||
t.Errorf("user: got %q, want %q", u.User.Username(), "alice")
|
||||
}
|
||||
if q := u.Query().Get("privateKeyFile"); q != keyPath {
|
||||
t.Errorf("privateKeyFile param: got %q, want %q", q, keyPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandSSHNode_PreserveExisting(t *testing.T) {
|
||||
cfg := map[string]*sshHostEntry{
|
||||
"myalias": {
|
||||
hostname: "real.host.com",
|
||||
port: "2222",
|
||||
user: "alice",
|
||||
identityFile: "/home/alice/.ssh/id_ed25519",
|
||||
},
|
||||
}
|
||||
|
||||
got := expandSSHNode("ssh://bob@myalias:22?key=/tmp/other_key", cfg)
|
||||
u, err := url.Parse(got)
|
||||
if err != nil {
|
||||
t.Fatalf("parse result URL: %v", err)
|
||||
}
|
||||
if u.Hostname() != "real.host.com" {
|
||||
t.Errorf("hostname: got %q, want %q", u.Hostname(), "real.host.com")
|
||||
}
|
||||
// Port 22 is default; config has 2222 → should be substituted.
|
||||
if u.Port() != "2222" {
|
||||
t.Errorf("port: got %q, want %q", u.Port(), "2222")
|
||||
}
|
||||
// Existing user in URL must be preserved.
|
||||
if u.User.Username() != "bob" {
|
||||
t.Errorf("user: got %q, want %q", u.User.Username(), "bob")
|
||||
}
|
||||
// Existing TLS key param must be preserved unchanged.
|
||||
if q := u.Query().Get("key"); q != "/tmp/other_key" {
|
||||
t.Errorf("key param: got %q, want %q", q, "/tmp/other_key")
|
||||
}
|
||||
// privateKeyFile from config must be added (key is TLS, not SSH).
|
||||
if q := u.Query().Get("privateKeyFile"); q != "/home/alice/.ssh/id_ed25519" {
|
||||
t.Errorf("privateKeyFile param: got %q, want %q", q, "/home/alice/.ssh/id_ed25519")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandSSHNode_PrivateKeyFileInURL_Preserved(t *testing.T) {
|
||||
cfg := map[string]*sshHostEntry{
|
||||
"srv": {hostname: "1.2.3.4", identityFile: "/home/user/.ssh/id_rsa"},
|
||||
}
|
||||
got := expandSSHNode("ssh://srv?privateKeyFile=/custom/key", cfg)
|
||||
u, _ := url.Parse(got)
|
||||
// privateKeyFile already set → identityFile from config should NOT override.
|
||||
if q := u.Query().Get("privateKeyFile"); q != "/custom/key" {
|
||||
t.Errorf("privateKeyFile param: got %q, want %q", q, "/custom/key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSSHConfig(t *testing.T) {
|
||||
cleanup := writeSSHConfig(t, `
|
||||
# global defaults
|
||||
Host *
|
||||
ServerAliveInterval 60
|
||||
|
||||
Host bastion
|
||||
HostName bastion.example.com
|
||||
Port 2222
|
||||
User deploy
|
||||
IdentityFile ~/.ssh/bastion_key
|
||||
|
||||
Host dev
|
||||
HostName 10.0.0.5
|
||||
User dev
|
||||
`)
|
||||
defer cleanup()
|
||||
|
||||
cfg := readSSHConfig()
|
||||
if cfg == nil {
|
||||
t.Fatal("readSSHConfig returned nil")
|
||||
}
|
||||
|
||||
e, ok := cfg["bastion"]
|
||||
if !ok {
|
||||
t.Fatal("expected 'bastion' entry")
|
||||
}
|
||||
if e.hostname != "bastion.example.com" {
|
||||
t.Errorf("hostname: got %q", e.hostname)
|
||||
}
|
||||
if e.port != "2222" {
|
||||
t.Errorf("port: got %q", e.port)
|
||||
}
|
||||
if e.user != "deploy" {
|
||||
t.Errorf("user: got %q", e.user)
|
||||
}
|
||||
if filepath.Base(e.identityFile) != "bastion_key" {
|
||||
t.Errorf("identityFile: got %q", e.identityFile)
|
||||
}
|
||||
|
||||
dev, ok := cfg["dev"]
|
||||
if !ok {
|
||||
t.Fatal("expected 'dev' entry")
|
||||
}
|
||||
if dev.hostname != "10.0.0.5" {
|
||||
t.Errorf("dev hostname: got %q", dev.hostname)
|
||||
}
|
||||
|
||||
// Wildcard Host * must not be indexed by exact key.
|
||||
if _, ok := cfg["*"]; ok {
|
||||
t.Error("wildcard Host * should not appear in entries map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandSSHNodes_Integration(t *testing.T) {
|
||||
cleanup := writeSSHConfig(t, `
|
||||
Host jump
|
||||
HostName jump.corp.example.com
|
||||
Port 2200
|
||||
User ops
|
||||
IdentityFile ~/.ssh/jump_key
|
||||
`)
|
||||
defer cleanup()
|
||||
|
||||
nodes := []string{
|
||||
"ssh://jump",
|
||||
"http://proxy:8080",
|
||||
"ssh://realhost:22",
|
||||
}
|
||||
got := expandSSHNodes(nodes)
|
||||
|
||||
u0, _ := url.Parse(got[0])
|
||||
if u0.Hostname() != "jump.corp.example.com" {
|
||||
t.Errorf("[0] hostname: got %q", u0.Hostname())
|
||||
}
|
||||
if u0.Port() != "2200" {
|
||||
t.Errorf("[0] port: got %q", u0.Port())
|
||||
}
|
||||
if u0.User.Username() != "ops" {
|
||||
t.Errorf("[0] user: got %q", u0.User.Username())
|
||||
}
|
||||
if got[1] != "http://proxy:8080" {
|
||||
t.Errorf("[1] non-SSH URL changed: %q", got[1])
|
||||
}
|
||||
// ssh://realhost:22 — not in config, unchanged.
|
||||
if got[2] != "ssh://realhost:22" {
|
||||
t.Errorf("[2] unknown host changed: %q", got[2])
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
package main
|
||||
|
||||
var (
|
||||
version = "3.2.6"
|
||||
version = "3.2.7"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue