-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathutils.go
More file actions
183 lines (162 loc) · 4.83 KB
/
utils.go
File metadata and controls
183 lines (162 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Copyright 2015-2025 CEA/DAM/DIF
// Author: Arnaud Guignard <arnaud.guignard@cea.fr>
// Contributor: Cyril Servant <cyril.servant@cea.fr>
//
// This software is governed by the CeCILL-B license under French law and
// abiding by the rules of distribution of free software. You can use,
// modify and/ or redistribute the software under the terms of the CeCILL-B
// license as circulated by CEA, CNRS and INRIA at the following URL
// "http://www.cecill.info".
package utils
import (
"crypto/sha1"
"fmt"
"net"
"os"
"os/user"
"sort"
"strconv"
"strings"
"time"
)
// Package-wide default values.
const (
	// DefaultSSHPort is the port assumed for an SSH server when an address
	// does not specify one.
	DefaultSSHPort = "22"

	// DefaultService is the name of the default service.
	DefaultService = "default"
)
// CalcSessionID derives a session identifier from a user name, a timestamp
// and the "host:port" string of a connection.
//
// The identifier is made of the first 5 bytes of the SHA-1 digest of
// "user@hostport@nanoseconds" (nanoseconds being t.UnixNano()), rendered as
// 10 uppercase hexadecimal characters.
func CalcSessionID(user string, t time.Time, hostport string) string {
	seed := fmt.Sprintf("%s@%s@%d", user, hostport, t.UnixNano())
	digest := sha1.Sum([]byte(seed))
	return fmt.Sprintf("%X", digest[:5])
}
// SplitHostPort splits a network address of the form "host" or "host:port"
// into host and port.
//
// If the address carries no port, it is returned unchanged as the host along
// with DefaultSSHPort. Otherwise the port (a number or a service name) is
// resolved with net.LookupPort for "tcp" and returned in decimal form.
//
// Errors:
//   - if the address cannot be split for a reason other than a missing port,
//     the original address and DefaultSSHPort are returned together with the
//     error reported by net.SplitHostPort;
//   - if the port cannot be resolved, two empty strings and an
//     "address ...: invalid port" error are returned.
func SplitHostPort(hostport string) (string, string, error) {
	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		// The net package exposes no sentinel for this condition, so the
		// AddrError text is matched. The assertion is checked so that an
		// unexpected error type is reported to the caller instead of
		// causing a panic.
		if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" {
			return hostport, DefaultSSHPort, nil
		}
		return hostport, DefaultSSHPort, err
	}

	portNum, err := net.LookupPort("tcp", port)
	if err != nil {
		return "", "", fmt.Errorf("address %s: invalid port", hostport)
	}
	return host, strconv.Itoa(portNum), nil
}
// Package-level indirections to user.Current, user.Lookup and
// user.LookupGroupId so that tests can replace them with mocks.
var (
	userCurrent       = user.Current
	userLookup        = user.Lookup
	userLookupGroupId = user.LookupGroupId
)
// GetGroupUser builds the set of group names the given user belongs to.
//
// The group IDs reported by u.GroupIds are each resolved to a group name,
// which becomes a key of the returned map (with value true), so membership
// can be tested with a simple map lookup.
//
// The first error encountered, either while listing the group IDs or while
// resolving one of them, is returned along with a nil map.
func GetGroupUser(u *user.User) (map[string]bool, error) {
	gids, err := u.GroupIds()
	if err != nil {
		return nil, err
	}

	memberships := make(map[string]bool)
	for i := range gids {
		group, lookupErr := userLookupGroupId(gids[i])
		if lookupErr != nil {
			return nil, lookupErr
		}
		memberships[group.Name] = true
	}
	return memberships, nil
}
// GetGroups returns the set of group names the current user belongs to.
//
// The keys of the returned map are group names, so membership can be tested
// with a simple map lookup. A nil map and an error are returned if the
// current user or its groups cannot be determined.
func GetGroups() (map[string]bool, error) {
	current, err := userCurrent()
	if err != nil {
		return nil, err
	}
	// GetGroupUser already yields a nil map on failure, so its result can be
	// handed back as is.
	return GetGroupUser(current)
}
// GetGroupList returns the set of group names the user named username
// belongs to.
//
// The keys of the returned map are group names, so membership can be tested
// with a simple map lookup. A nil map and an error are returned if the user
// or its groups cannot be looked up.
func GetGroupList(username string) (map[string]bool, error) {
	account, err := userLookup(username)
	if err != nil {
		return nil, err
	}
	// GetGroupUser already yields a nil map on failure, so its result can be
	// handed back as is.
	return GetGroupUser(account)
}
// GetSortedGroups returns the names of the groups of the user named
// username, sorted and separated by single spaces.
//
// If the groups cannot be determined (which happens, for instance, when a
// user has been deleted but still has an open connection), the error is
// written to standard error and an empty string is returned.
func GetSortedGroups(username string) string {
	memberships, err := GetGroupList(username)
	if err != nil {
		fmt.Fprintln(os.Stderr, err.Error())
		return ""
	}

	names := make([]string, 0, len(memberships))
	for name := range memberships {
		names = append(names, name)
	}
	// Map iteration order is random: sort to get a stable result.
	sort.Strings(names)
	return strings.Join(names, " ")
}
// netLookupHost is a package-level indirection to net.LookupHost so that
// tests can replace it with a mock.
var netLookupHost = net.LookupHost
// normalizeHostPort turns a "host" or "host:port" string into a list of IP
// addresses and a port.
//
// The address is split with SplitHostPort. If the host part is an IP literal
// it is returned as the only address; otherwise it is treated as a name and
// resolved with netLookupHost, which may yield several addresses.
//
// On failure a nil slice, an empty port and an error are returned. The
// underlying error is wrapped, so callers can inspect it with errors.Is or
// errors.As.
func normalizeHostPort(hostPort string) ([]string, string, error) {
	host, port, err := SplitHostPort(hostPort)
	if err != nil {
		return nil, "", fmt.Errorf("invalid address: %w", err)
	}

	if net.ParseIP(host) != nil {
		// host is already an IP address: nothing to resolve.
		return []string{host}, port, nil
	}

	// host is a name and not an IP address; this name can resolve to
	// multiple IPs.
	addrs, err := netLookupHost(host)
	if err != nil {
		return nil, "", fmt.Errorf("cannot resolve host '%s': %w", host, err)
	}
	return addrs, port, nil
}
// MatchSource reports whether source designates the SSH host and port of the
// incoming ssh connection given by sshdHostPort.
//
// Both arguments are normalized with normalizeHostPort (names are resolved,
// a missing port defaults to 22). They match when their ports are equal and
// they have at least one IP address in common.
//
// If either argument cannot be normalized, false is returned with an error
// prefixed by "source: " or "sshdHostPort: ". The underlying error is
// wrapped, so callers can inspect it with errors.Is or errors.As.
func MatchSource(source string, sshdHostPort string) (bool, error) {
	sourceAddrs, sourcePort, err := normalizeHostPort(source)
	if err != nil {
		return false, fmt.Errorf("source: %w", err)
	}

	// Normalized before comparing ports, so that an invalid sshdHostPort is
	// always reported.
	sshdAddrs, sshdPort, err := normalizeHostPort(sshdHostPort)
	if err != nil {
		return false, fmt.Errorf("sshdHostPort: %w", err)
	}

	if sourcePort != sshdPort {
		return false, nil
	}

	for _, sourceAddr := range sourceAddrs {
		for _, sshdAddr := range sshdAddrs {
			if sourceAddr == sshdAddr {
				return true, nil
			}
		}
	}
	return false, nil
}