Skip to content

Commit 05266d6

Browse files
committed
Add NetFromRange and NetFromInterval helpers
When retrieving set elements it can be desired to format the result in the same way `nft` would, which is merging intervals to CIDR representations. To make this easier, introduce helper functions which allow for conversion of IP address ranges to CIDR networks. Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
1 parent 1db35da commit 05266d6

2 files changed

Lines changed: 305 additions & 0 deletions

File tree

util.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,19 @@ package nftables
1616

1717
import (
1818
"encoding/binary"
19+
"errors"
1920
"net"
21+
"net/netip"
2022

2123
"github.com/google/nftables/binaryutil"
2224
"golang.org/x/sys/unix"
2325
)
2426

27+
var (
28+
MaxIPv4 = net.IP{255, 255, 255, 255}
29+
MaxIPv6 = net.IP{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
30+
)
31+
2532
func extraHeader(family uint8, resID uint16) []byte {
2633
return append([]byte{
2734
family,
@@ -126,3 +133,83 @@ func NetInterval(cidr string) (net.IP, net.IP, error) {
126133

127134
return first, nextIP(last), nil
128135
}
136+
137+
// endIp returns the last address in a given network.
138+
func endIp(netIp net.IP, mask net.IPMask) net.IP {
139+
ip := make(net.IP, len(netIp))
140+
copy(ip, netIp)
141+
142+
for i := 0; i < len(mask); i++ {
143+
ipIdx := len(ip) - i - 1
144+
ip[ipIdx] = netIp[ipIdx] | ^mask[len(mask)-i-1]
145+
}
146+
147+
return ip
148+
}
149+
150+
// NetFromRange returns a CIDR IP network given a start and end address.
151+
// If the network is an exact match, ok will be true.
152+
func NetFromRange(first net.IP, last net.IP) (*net.IPNet, bool, error) {
153+
ip1 := net.IP(first)
154+
ip2 := net.IP(last)
155+
156+
maxLen := 32
157+
isIpv6 := ip1.To4() == nil
158+
159+
if isIpv6 && ip2.To4() != nil || !isIpv6 && ip2.To4() == nil {
160+
return nil, false, errors.New("Cannot mix IPv4 and IPv6 or process empty IP.")
161+
}
162+
163+
if isIpv6 {
164+
maxLen = 128
165+
}
166+
167+
var match *net.IPNet
168+
for l := maxLen; l >= -1; l-- {
169+
cidrmask := net.CIDRMask(l, maxLen)
170+
ipmask := ip2.Mask(cidrmask)
171+
ipnet := net.IPNet{
172+
IP: ipmask,
173+
Mask: cidrmask,
174+
}
175+
176+
if ipnet.Contains(ip1) {
177+
match = &ipnet
178+
break
179+
}
180+
181+
}
182+
183+
matchFirst := match.IP.Equal(ip1)
184+
185+
// short-circuit if first address is not start of the network
186+
if !matchFirst {
187+
return match, matchFirst, nil
188+
}
189+
190+
return match, endIp(match.IP, match.Mask).Equal(ip2), nil
191+
}
192+
193+
// NetFromInterval returns a CIDR IP network given a start and end address as found in intervals.
194+
// This is similar to NetFromRange, but subtracts one address from the end of the range.
195+
// If the resulting network is an exact match, ok will be true.
196+
func NetFromInterval(first net.IP, last net.IP) (out *net.IPNet, ok bool, err error) {
197+
var previous net.IP
198+
199+
if len(last) == 0 {
200+
if first.To4() == nil {
201+
previous = MaxIPv6
202+
} else {
203+
previous = MaxIPv4
204+
}
205+
} else {
206+
ip2, ok := netip.AddrFromSlice(last)
207+
if !ok {
208+
return nil, false, errors.New("Failed to construct slice from network.")
209+
}
210+
211+
previous = ip2.Prev().AsSlice()
212+
}
213+
214+
return NetFromRange(first, previous)
215+
}

util_test.go

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,221 @@ func TestNetInterval(t *testing.T) {
201201
})
202202
}
203203
}
204+
205+
func TestEndIp(t *testing.T) {
206+
tests := []struct {
207+
network string
208+
wantEndIp string
209+
}{
210+
{
211+
network: "10.0.0.0/24",
212+
wantEndIp: "10.0.0.255",
213+
},
214+
{
215+
network: "192.168.4.32/27",
216+
wantEndIp: "192.168.4.63",
217+
},
218+
{
219+
network: "2001:db8:100::/64",
220+
wantEndIp: "2001:db8:100:0:ffff:ffff:ffff:ffff",
221+
},
222+
{
223+
network: "2001:db8:100:a:b::50/64",
224+
wantEndIp: "2001:db8:100:a:ffff:ffff:ffff:ffff",
225+
},
226+
}
227+
for _, tt := range tests {
228+
taddr, tnet, err := net.ParseCIDR(tt.network)
229+
if err != nil {
230+
t.Fatalf("endIp() error parsing test CIDR = %v", err)
231+
}
232+
233+
t.Run(tnet.String(), func(t *testing.T) {
234+
gotEndIp := endIp(taddr, tnet.Mask)
235+
if !gotEndIp.Equal(net.ParseIP(tt.wantEndIp)) {
236+
t.Errorf("endIp() gotEndIp = %s, wantEndIp = %s", gotEndIp, tt.wantEndIp)
237+
}
238+
})
239+
}
240+
}
241+
242+
func TestNetFromRange(t *testing.T) {
243+
tests := []struct {
244+
name string
245+
first string
246+
last string
247+
wantNet string
248+
wantOk bool
249+
wantErr bool
250+
}{
251+
{
252+
first: "0.0.0.0",
253+
last: "255.255.255.255",
254+
wantNet: "0.0.0.0/0",
255+
wantOk: true,
256+
wantErr: false,
257+
},
258+
{
259+
first: "0.0.0.1",
260+
last: "255.255.255.254",
261+
wantNet: "0.0.0.0/0",
262+
wantOk: false,
263+
wantErr: false,
264+
},
265+
{
266+
first: "192.168.4.0",
267+
last: "192.168.4.255",
268+
wantNet: "192.168.4.0/24",
269+
wantOk: true,
270+
wantErr: false,
271+
},
272+
{
273+
first: "192.0.2.16",
274+
last: "192.0.2.30",
275+
wantNet: "192.0.2.16/28",
276+
wantOk: false,
277+
wantErr: false,
278+
},
279+
{
280+
first: "2001:db8:100::",
281+
last: "2001:db8:100:ffff:ffff:ffff:ffff:ffff",
282+
wantNet: "2001:db8:100::/48",
283+
wantOk: true,
284+
wantErr: false,
285+
},
286+
{
287+
first: "2001:db8:100::100",
288+
last: "2001:db8:100:0:ffff:ffff:ffff:ffff",
289+
wantNet: "2001:db8:100::/64",
290+
wantOk: false,
291+
wantErr: false,
292+
},
293+
{
294+
first: "2001:db8:100::",
295+
last: "192.0.2.30",
296+
wantNet: "",
297+
wantOk: true,
298+
wantErr: true,
299+
},
300+
{
301+
first: "192.0.2.30",
302+
last: "2001:db8:100::",
303+
wantNet: "",
304+
wantOk: true,
305+
wantErr: true,
306+
},
307+
}
308+
309+
for _, tt := range tests {
310+
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
311+
gotNet, gotOk, err := NetFromRange(net.ParseIP(tt.first), net.ParseIP(tt.last))
312+
if (err != nil) != tt.wantErr {
313+
t.Errorf("NetFromRange() error = %v, wantErr = %v", err, tt.wantErr)
314+
}
315+
316+
if tt.wantNet == "" {
317+
return
318+
}
319+
320+
_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
321+
if err != nil {
322+
t.Fatalf("NetFromRange() error parsing test network = %v", err)
323+
}
324+
325+
if tt.wantOk != gotOk {
326+
t.Errorf("NetFromRange() gotOk = %t, wantOk = %t", gotOk, tt.wantOk)
327+
}
328+
329+
if !reflect.DeepEqual(gotNet, wantNetParsed) {
330+
t.Errorf("NetFromRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
331+
}
332+
})
333+
}
334+
}
335+
336+
func TestNetFromInterval(t *testing.T) {
337+
tests := []struct {
338+
name string
339+
first string
340+
last string
341+
wantNet string
342+
wantOk bool
343+
wantErr bool
344+
}{
345+
{
346+
first: "192.0.2.16",
347+
last: "192.0.2.32",
348+
wantNet: "192.0.2.16/28",
349+
wantOk: true,
350+
wantErr: false,
351+
},
352+
{
353+
first: "128.0.0.0",
354+
last: "",
355+
wantNet: "128.0.0.0/1",
356+
wantOk: true,
357+
wantErr: false,
358+
},
359+
{
360+
first: "2001:db8:100::",
361+
last: "2001:db8:101::",
362+
wantNet: "2001:db8:100::/48",
363+
wantOk: true,
364+
wantErr: false,
365+
},
366+
{
367+
first: "2001:db8:a1:11::",
368+
last: "2001:db8:a1:12::",
369+
wantNet: "2001:db8:a1:11::/64",
370+
wantOk: true,
371+
wantErr: false,
372+
},
373+
{
374+
first: "2001:db8:100::100",
375+
last: "2001:db8:100:0:ffff:ffff:ffff:ffff",
376+
wantNet: "2001:db8:100::/64",
377+
wantOk: false,
378+
wantErr: false,
379+
},
380+
{
381+
first: "2001:db8:100::",
382+
last: "192.0.2.30",
383+
wantNet: "",
384+
wantOk: true,
385+
wantErr: true,
386+
},
387+
{
388+
first: "192.0.2.30",
389+
last: "2001:db8:100::",
390+
wantNet: "",
391+
wantOk: true,
392+
wantErr: true,
393+
},
394+
}
395+
396+
for _, tt := range tests {
397+
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
398+
gotNet, gotOk, err := NetFromInterval(net.ParseIP(tt.first), net.ParseIP(tt.last))
399+
if (err != nil) != tt.wantErr {
400+
t.Errorf("NetFromInterval() error = %v, wantErr = %v", err, tt.wantErr)
401+
}
402+
403+
if tt.wantNet == "" {
404+
return
405+
}
406+
407+
_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
408+
if err != nil {
409+
t.Fatalf("NetFromInterval() error parsing test network = %v", err)
410+
}
411+
412+
if tt.wantOk != gotOk {
413+
t.Errorf("NetFromInterval() gotOk = %t, wantOk = %t", gotOk, tt.wantOk)
414+
}
415+
416+
if !reflect.DeepEqual(gotNet, wantNetParsed) {
417+
t.Errorf("NetFromInterval() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
418+
}
419+
})
420+
}
421+
}

0 commit comments

Comments
 (0)