diff --git a/Package.resolved b/Package.resolved index 3ea06388b..3b29d5751 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "332c365348153b80a80666112de5736d8c1f9f66e57bfcbf01390bf0b1c062f0", + "originHash" : "4ec05f4e83999a89d3397d0657536924d4a425d7f0e3f0fd6a3578e34c924502", "pins" : [ { "identity" : "async-http-client", @@ -19,24 +19,6 @@ "version" : "0.29.0" } }, - { - "identity" : "dns", - "kind" : "remoteSourceControl", - "location" : "https://github.com/Bouke/DNS.git", - "state" : { - "revision" : "78bbd1589890a90b202d11d5f9e1297050cf0eb2", - "version" : "1.2.0" - } - }, - { - "identity" : "dnsclient", - "kind" : "remoteSourceControl", - "location" : "https://github.com/orlandos-nl/DNSClient.git", - "state" : { - "revision" : "551fbddbf4fa728d4cd86f6a5208fe4f925f0549", - "version" : "2.4.4" - } - }, { "identity" : "grpc-swift", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 1d916cd52..1bbb3c309 100644 --- a/Package.swift +++ b/Package.swift @@ -47,7 +47,6 @@ let package = Package( .library(name: "TerminalProgress", targets: ["TerminalProgress"]), ], dependencies: [ - .package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"), .package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)), .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"), .package(url: "https://github.com/apple/swift-collections.git", from: "1.2.0"), @@ -56,7 +55,6 @@ let package = Package( .package(url: "https://github.com/apple/swift-protobuf.git", from: "1.29.0"), .package(url: "https://github.com/apple/swift-system.git", from: "1.4.0"), .package(url: "https://github.com/grpc/grpc-swift.git", from: "1.26.0"), - .package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"), .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"), .package(url: "https://github.com/swiftlang/swift-docc-plugin.git", from: "1.1.0"), ], @@ -427,17 +425,15 @@ let package = Package( dependencies: [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - .product(name: "DNSClient", package: "DNSClient"), - .product(name: "DNS", package: "DNS"), .product(name: "Logging", package: "swift-log"), + .product(name: "ContainerizationExtras", package: "containerization"), .product(name: "ContainerizationOS", package: "containerization"), ] ), .testTarget( name: "DNSServerTests", dependencies: [ - .product(name: "DNS", package: "DNS"), - "DNSServer", + "DNSServer" ] ), .testTarget( diff --git a/Sources/DNSServer/DNSServer+Handle.swift b/Sources/DNSServer/DNSServer+Handle.swift index e258fb482..9b2bc9fa5 100644 --- a/Sources/DNSServer/DNSServer+Handle.swift +++ b/Sources/DNSServer/DNSServer+Handle.swift @@ -27,23 +27,32 @@ extension DNSServer { outbound: NIOAsyncChannelOutboundWriter>, packet: inout AddressedEnvelope ) async throws { - let chunkSize = 512 - var data = Data() + // RFC 1035 §2.3.4 limits UDP DNS messages to 512 bytes. We don't implement + // EDNS0 (RFC 6891), and this server only resolves host A/AAAA queries, so a + // legitimate query will never approach this limit. Reject oversized packets + // before reading to avoid allocating memory for malformed or malicious datagrams. + let maxPacketSize = 512 + guard packet.data.readableBytes <= maxPacketSize else { + self.log?.error("dropping oversized DNS packet: \(packet.data.readableBytes) bytes") + return + } + var data = Data() self.log?.debug("reading data") while packet.data.readableBytes > 0 { - if let chunk = packet.data.readBytes(length: min(chunkSize, packet.data.readableBytes)) { + if let chunk = packet.data.readBytes(length: packet.data.readableBytes) { data.append(contentsOf: chunk) } } self.log?.debug("deserializing message") - let query = try Message(deserialize: data) - self.log?.debug("processing query: \(query.questions)") // always send response let responseData: Data do { + let query = try Message(deserialize: data) + self.log?.debug("processing query: \(query.questions)") + self.log?.debug("awaiting processing") var response = try await handler.answer(query: query) @@ -64,21 +73,48 @@ extension DNSServer { self.log?.debug("serializing response") responseData = try response.serialize() + } catch let error as DNSBindError { + // Best-effort: echo the transaction ID from the first two bytes of the raw packet. + let rawId = data.count >= 2 ? data[0..<2].withUnsafeBytes { $0.load(as: UInt16.self) } : 0 + let id = UInt16(bigEndian: rawId) + let returnCode: ReturnCode + switch error { + case .unsupportedValue: + self.log?.error("not implemented processing DNS message: \(error)") + returnCode = .notImplemented + default: + self.log?.error("format error processing DNS message: \(error)") + returnCode = .formatError + } + let response = Message( + id: id, + type: .response, + returnCode: returnCode, + questions: [], + answers: [] + ) + responseData = try response.serialize() } catch { - self.log?.error("error processing message from \(query): \(error)") + let rawId = data.count >= 2 ? data[0..<2].withUnsafeBytes { $0.load(as: UInt16.self) } : 0 + let id = UInt16(bigEndian: rawId) + self.log?.error("error processing DNS message: \(error)") let response = Message( - id: query.id, + id: id, type: .response, - returnCode: .notImplemented, - questions: query.questions, + returnCode: .serverFailure, + questions: [], answers: [] ) responseData = try response.serialize() } - self.log?.debug("sending response for \(query.id)") + self.log?.debug("sending response") let rData = ByteBuffer(bytes: responseData) - try? await outbound.write(AddressedEnvelope(remoteAddress: packet.remoteAddress, data: rData)) + do { + try await outbound.write(AddressedEnvelope(remoteAddress: packet.remoteAddress, data: rData)) + } catch { + self.log?.error("failed to send DNS response: \(error)") + } self.log?.debug("processing done") diff --git a/Sources/DNSServer/Handlers/HostTableResolver.swift b/Sources/DNSServer/Handlers/HostTableResolver.swift index 0bc247609..41a679677 100644 --- a/Sources/DNSServer/Handlers/HostTableResolver.swift +++ b/Sources/DNSServer/Handlers/HostTableResolver.swift @@ -14,30 +14,44 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras /// Handler that uses table lookup to resolve hostnames. +/// +/// Keys in `hosts4` are normalized to `DNSName` on construction, so lookups +/// are case-insensitive and trailing dots are optional. public struct HostTableResolver: DNSHandler { - public let hosts4: [String: IPv4] + public let hosts4: [DNSName: IPv4Address] private let ttl: UInt32 - public init(hosts4: [String: IPv4], ttl: UInt32 = 300) { - self.hosts4 = hosts4 + /// Creates a resolver backed by a static IPv4 host table. + /// + /// - Parameter hosts4: A dictionary mapping domain names to IPv4 addresses. + /// Keys are normalized to `DNSName` (lowercased, trailing dot stripped), so + /// `"FOO."`, `"foo."`, and `"foo"` all refer to the same entry. + /// - Parameter ttl: The TTL in seconds to set on answer records (default is 300). + /// - Throws: `DNSBindError.invalidName` if any key is not a valid DNS name. + public init(hosts4: [String: IPv4Address], ttl: UInt32 = 300) throws { + self.hosts4 = try Dictionary(uniqueKeysWithValues: hosts4.map { (try DNSName($0.key), $0.value) }) self.ttl = ttl } public func answer(query: Message) async throws -> Message? { - let question = query.questions[0] + guard let question = query.questions.first else { + return nil + } + let n = question.name.hasSuffix(".") ? String(question.name.dropLast()) : question.name + let key = try DNSName(labels: n.isEmpty ? [] : n.split(separator: ".", omittingEmptySubsequences: false).map(String.init)) let record: ResourceRecord? switch question.type { case ResourceRecordType.host: - record = answerHost(question: question) + record = answerHost(question: question, key: key) case ResourceRecordType.host6: // Return NODATA (noError with empty answers) for AAAA queries ONLY if A record exists. // This is required because musl libc has issues when A record exists but AAAA returns NXDOMAIN. // musl treats NXDOMAIN on AAAA as "domain doesn't exist" and fails DNS resolution entirely. // NODATA correctly indicates "no IPv6 address available, but domain exists". - if hosts4[question.name] != nil { + if hosts4[key] != nil { return Message( id: query.id, type: .response, @@ -48,28 +62,11 @@ public struct HostTableResolver: DNSHandler { } // If hostname doesn't exist, return nil which will become NXDOMAIN return nil - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) @@ -88,11 +85,11 @@ public struct HostTableResolver: DNSHandler { ) } - private func answerHost(question: Question) -> ResourceRecord? { - guard let ip = hosts4[question.name] else { + private func answerHost(question: Question, key: DNSName) -> ResourceRecord? { + guard let ip = hosts4[key] else { return nil } - return HostRecord(name: question.name, ttl: ttl, ip: ip) + return HostRecord(name: question.name, ttl: ttl, ip: ip) } } diff --git a/Sources/DNSServer/Handlers/NxDomainResolver.swift b/Sources/DNSServer/Handlers/NxDomainResolver.swift index 68b36bf1d..8fa9c05b4 100644 --- a/Sources/DNSServer/Handlers/NxDomainResolver.swift +++ b/Sources/DNSServer/Handlers/NxDomainResolver.swift @@ -14,8 +14,6 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS - /// Handler that returns NXDOMAIN for all hostnames. public struct NxDomainResolver: DNSHandler { private let ttl: UInt32 @@ -35,29 +33,11 @@ public struct NxDomainResolver: DNSHandler { questions: query.questions, answers: [] ) - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.host6, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) diff --git a/Sources/DNSServer/Records/DNSBindError.swift b/Sources/DNSServer/Records/DNSBindError.swift new file mode 100644 index 000000000..b054d42e7 --- /dev/null +++ b/Sources/DNSServer/Records/DNSBindError.swift @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +/// Errors that can occur during DNS message serialization/deserialization. +public enum DNSBindError: Error, CustomStringConvertible { + case marshalFailure(type: String, field: String) + case unmarshalFailure(type: String, field: String) + case unsupportedValue(type: String, field: String) + case invalidName(String) + case unexpectedOffset(type: String, expected: Int, actual: Int) + + public var description: String { + switch self { + case .marshalFailure(let type, let field): + return "failed to marshal \(type).\(field)" + case .unmarshalFailure(let type, let field): + return "failed to unmarshal \(type).\(field)" + case .unsupportedValue(let type, let field): + return "unsupported value for \(type).\(field)" + case .invalidName(let reason): + return "invalid DNS name: \(reason)" + case .unexpectedOffset(let type, let expected, let actual): + return "unexpected offset serializing \(type): expected \(expected), got \(actual)" + } + } +} diff --git a/Sources/DNSServer/Records/DNSEnums.swift b/Sources/DNSServer/Records/DNSEnums.swift new file mode 100644 index 000000000..abea14a58 --- /dev/null +++ b/Sources/DNSServer/Records/DNSEnums.swift @@ -0,0 +1,167 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +/// DNS message type (query or response). +public enum MessageType: UInt16, Sendable { + case query = 0 + case response = 1 +} + +/// DNS operation code (RFC 1035, 1996, 2136). +public enum OperationCode: UInt8, Sendable { + case query = 0 // Standard query (RFC 1035) + case inverseQuery = 1 // Inverse query (obsolete, RFC 3425) + case status = 2 // Server status request (RFC 1035) + // 3 is reserved + case notify = 4 // Zone change notification (RFC 1996) + case update = 5 // Dynamic update (RFC 2136) + case dso = 6 // DNS Stateful Operations (RFC 8490) + // 7-15 reserved +} + +/// DNS response return codes (RFC 1035, 2136, 2845, 6895). +public enum ReturnCode: UInt8, Sendable { + case noError = 0 // No error + case formatError = 1 // Format error - unable to interpret query + case serverFailure = 2 // Server failure + case nonExistentDomain = 3 // Name error - domain does not exist (NXDOMAIN) + case notImplemented = 4 // Not implemented - query type not supported + case refused = 5 // Refused - policy restriction + case yxDomain = 6 // Name exists when it should not (RFC 2136) + case yxRRSet = 7 // RR set exists when it should not (RFC 2136) + case nxRRSet = 8 // RR set does not exist when it should (RFC 2136) + case notAuthoritative = 9 // Server not authoritative (RFC 2136) / Not authorized (RFC 2845) + case notZone = 10 // Name not in zone (RFC 2136) + case dsoTypeNotImplemented = 11 // DSO-TYPE not implemented (RFC 8490) + // 12-15 reserved + case badSignature = 16 // TSIG signature failure (RFC 2845) + case badKey = 17 // Key not recognized (RFC 2845) + case badTime = 18 // Signature out of time window (RFC 2845) + case badMode = 19 // Bad TKEY mode (RFC 2930) + case badName = 20 // Duplicate key name (RFC 2930) + case badAlgorithm = 21 // Algorithm not supported (RFC 2930) + case badTruncation = 22 // Bad truncation (RFC 4635) + case badCookie = 23 // Bad/missing server cookie (RFC 7873) +} + +/// DNS resource record types (RFC 1035, 3596, 2782, and others). +public enum ResourceRecordType: UInt16, Sendable { + case host = 1 // A - IPv4 address (RFC 1035) + case nameServer = 2 // NS - Authoritative name server (RFC 1035) + case mailDestination = 3 // MD - Mail destination (obsolete, RFC 1035) + case mailForwarder = 4 // MF - Mail forwarder (obsolete, RFC 1035) + case alias = 5 // CNAME - Canonical name (RFC 1035) + case startOfAuthority = 6 // SOA - Start of authority (RFC 1035) + case mailbox = 7 // MB - Mailbox domain name (experimental, RFC 1035) + case mailGroup = 8 // MG - Mail group member (experimental, RFC 1035) + case mailRename = 9 // MR - Mail rename domain name (experimental, RFC 1035) + case null = 10 // NULL - Null RR (experimental, RFC 1035) + case wellKnownService = 11 // WKS - Well known service (RFC 1035) + case pointer = 12 // PTR - Domain name pointer (RFC 1035) + case hostInfo = 13 // HINFO - Host information (RFC 1035) + case mailInfo = 14 // MINFO - Mailbox information (RFC 1035) + case mailExchange = 15 // MX - Mail exchange (RFC 1035) + case text = 16 // TXT - Text strings (RFC 1035) + case responsiblePerson = 17 // RP - Responsible person (RFC 1183) + case afsDatabase = 18 // AFSDB - AFS database location (RFC 1183) + case x25 = 19 // X25 - X.25 PSDN address (RFC 1183) + case isdn = 20 // ISDN - ISDN address (RFC 1183) + case routeThrough = 21 // RT - Route through (RFC 1183) + case nsapAddress = 22 // NSAP - NSAP address (RFC 1706) + case nsapPointer = 23 // NSAP-PTR - NSAP pointer (RFC 1706) + case signature = 24 // SIG - Security signature (RFC 2535) + case key = 25 // KEY - Security key (RFC 2535) + case pxRecord = 26 // PX - X.400 mail mapping (RFC 2163) + case gpos = 27 // GPOS - Geographical position (RFC 1712) + case host6 = 28 // AAAA - IPv6 address (RFC 3596) + case location = 29 // LOC - Location information (RFC 1876) + case nextDomain = 30 // NXT - Next domain (obsolete, RFC 2535) + case endpointId = 31 // EID - Endpoint identifier + case nimrodLocator = 32 // NIMLOC - Nimrod locator + case service = 33 // SRV - Service locator (RFC 2782) + case atma = 34 // ATMA - ATM address + case namingPointer = 35 // NAPTR - Naming authority pointer (RFC 3403) + case keyExchange = 36 // KX - Key exchange (RFC 2230) + case cert = 37 // CERT - Certificate (RFC 4398) + case a6Record = 38 // A6 - IPv6 address (obsolete, RFC 2874) + case dname = 39 // DNAME - Delegation name (RFC 6672) + case sink = 40 // SINK - Kitchen sink + case opt = 41 // OPT - EDNS option (RFC 6891) + case apl = 42 // APL - Address prefix list (RFC 3123) + case delegationSigner = 43 // DS - Delegation signer (RFC 4034) + case sshFingerprint = 44 // SSHFP - SSH key fingerprint (RFC 4255) + case ipsecKey = 45 // IPSECKEY - IPsec key (RFC 4025) + case resourceSignature = 46 // RRSIG - Resource record signature (RFC 4034) + case nsec = 47 // NSEC - Next secure record (RFC 4034) + case dnsKey = 48 // DNSKEY - DNS key (RFC 4034) + case dhcid = 49 // DHCID - DHCP identifier (RFC 4701) + case nsec3 = 50 // NSEC3 - NSEC3 (RFC 5155) + case nsec3Param = 51 // NSEC3PARAM - NSEC3 parameters (RFC 5155) + case tlsa = 52 // TLSA - TLSA certificate (RFC 6698) + case smimea = 53 // SMIMEA - S/MIME cert association (RFC 8162) + // 54 unassigned + case hip = 55 // HIP - Host identity protocol (RFC 8005) + case ninfo = 56 // NINFO + case rkey = 57 // RKEY + case taLink = 58 // TALINK - Trust anchor link + case cds = 59 // CDS - Child DS (RFC 7344) + case cdnsKey = 60 // CDNSKEY - Child DNSKEY (RFC 7344) + case openPGPKey = 61 // OPENPGPKEY - OpenPGP key (RFC 7929) + case csync = 62 // CSYNC - Child-to-parent sync (RFC 7477) + case zoneDigest = 63 // ZONEMD - Zone message digest (RFC 8976) + case svcBinding = 64 // SVCB - Service binding (RFC 9460) + case httpsBinding = 65 // HTTPS - HTTPS binding (RFC 9460) + // 66-98 unassigned + case spf = 99 // SPF - Sender policy framework (RFC 7208) + case uinfo = 100 // UINFO + case uid = 101 // UID + case gid = 102 // GID + case unspec = 103 // UNSPEC + case nid = 104 // NID - Node identifier (RFC 6742) + case l32 = 105 // L32 - Locator32 (RFC 6742) + case l64 = 106 // L64 - Locator64 (RFC 6742) + case lp = 107 // LP - Locator FQDN (RFC 6742) + case eui48 = 108 // EUI48 - 48-bit MAC (RFC 7043) + case eui64 = 109 // EUI64 - 64-bit MAC (RFC 7043) + // 110-248 unassigned + case tkey = 249 // TKEY - Transaction key (RFC 2930) + case tsig = 250 // TSIG - Transaction signature (RFC 2845) + case incrementalZoneTransfer = 251 // IXFR - Incremental zone transfer (RFC 1995) + case standardZoneTransfer = 252 // AXFR - Full zone transfer (RFC 1035) + case mailboxRecords = 253 // MAILB - Mailbox-related records (RFC 1035) + case mailAgentRecords = 254 // MAILA - Mail agent RRs (obsolete, RFC 1035) + case all = 255 // * - All records (RFC 1035) + case uri = 256 // URI - Uniform resource identifier (RFC 7553) + case caa = 257 // CAA - Certification authority authorization (RFC 8659) + case avc = 258 // AVC - Application visibility and control + case doa = 259 // DOA - Digital object architecture + case amtRelay = 260 // AMTRELAY - Automatic multicast tunneling relay (RFC 8777) + case resInfo = 261 // RESINFO - Resolver information + // ... + case ta = 32768 // TA - DNSSEC trust authorities + case dlv = 32769 // DLV - DNSSEC lookaside validation (RFC 4431) +} + +/// DNS resource record class (RFC 1035). +public enum ResourceRecordClass: UInt16, Sendable { + case internet = 1 // IN - Internet (RFC 1035) + // 2 unassigned + case chaos = 3 // CH - Chaos (RFC 1035) + case hesiod = 4 // HS - Hesiod (RFC 1035) + // 5-253 unassigned + case none = 254 // NONE - None (RFC 2136) + case any = 255 // * - Any class (RFC 1035) +} diff --git a/Sources/DNSServer/Records/DNSName.swift b/Sources/DNSServer/Records/DNSName.swift new file mode 100644 index 000000000..32d277a33 --- /dev/null +++ b/Sources/DNSServer/Records/DNSName.swift @@ -0,0 +1,204 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation + +/// A DNS name encoded as a sequence of labels. +/// +/// DNS names are encoded as: `[length][label][length][label]...[0]` +/// For example, "example.com" becomes: `[7]example[3]com[0]` +public struct DNSName: Sendable, Hashable, CustomStringConvertible { + /// The labels that make up this name (e.g., ["example", "com"]). + public private(set) var labels: [String] + + /// Creates a DNS name representing the root (empty label list). + public init() { + self.labels = [] + } + + /// Creates a validated DNS name from an array of labels. + /// + /// Validates structural RFC 1035 constraints only: no empty labels, each label ≤ 63 + /// bytes, total wire length ≤ 255 bytes. Does not enforce hostname character rules. + /// Labels are lowercased to normalize for case-insensitive DNS comparison. + /// + /// - Throws: `DNSBindError.invalidName` if any label is empty, exceeds 63 bytes, + /// or if the total wire representation exceeds 255 bytes. + public init(labels: [String]) throws { + for label in labels { + guard !label.isEmpty else { + throw DNSBindError.invalidName("empty label") + } + guard label.utf8.count <= 63 else { + throw DNSBindError.invalidName("label too long: \"\(label)\"") + } + } + let wireLength = labels.reduce(1) { $0 + 1 + $1.utf8.count } + guard wireLength <= 255 else { + throw DNSBindError.invalidName("name too long") + } + self.labels = labels.map { $0.lowercased() } + } + + /// Creates a validated DNS name from a dot-separated hostname string + /// (e.g., `"example.com."` or `"example.com"`). + /// + /// A trailing dot is accepted but not required. + /// An empty string produces the root name without error. + /// + /// Labels must start and end with a letter or digit (LDH hostname rule). + /// Use `init(labels:)` directly when working with wire-decoded names that + /// may contain non-hostname labels (e.g. service-discovery labels like `"_dns"`). + /// + /// - Throws: `DNSBindError.invalidName` if any label violates the character rules, + /// or if structural limits are exceeded (see `init(labels:)`). + public init(_ hostname: String) throws { + let normalized = hostname.hasSuffix(".") ? String(hostname.dropLast()) : hostname + guard !normalized.isEmpty else { + self.init() + return + } + let parts = normalized.split(separator: ".", omittingEmptySubsequences: false).map { String($0) } + let hostnameRegex = /[a-zA-Z0-9](?:[a-zA-Z0-9\-_]*[a-zA-Z0-9])?/ + for part in parts { + guard part.wholeMatch(of: hostnameRegex) != nil else { + throw DNSBindError.invalidName( + "label must start and end with a letter or digit: \"\(part)\"" + ) + } + } + try self.init(labels: parts) + } + + /// The wire format size of this name in bytes. + public var size: Int { + // Each label: 1 byte length + label bytes, plus 1 byte for null terminator + labels.reduce(1) { $0 + 1 + $1.utf8.count } + } + + /// The fully-qualified domain name with trailing dot. + public var description: String { + labels.isEmpty ? "." : labels.joined(separator: ".") + "." + } + + /// Serialize this name into the buffer at the given offset. + public func appendBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int { + let startOffset = offset + var offset = offset + + for label in labels { + let bytes = Array(label.utf8) + guard bytes.count <= 63 else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "label") + } + + guard let newOffset = buffer.copyIn(as: UInt8.self, value: UInt8(bytes.count), offset: offset) else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "label") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(buffer: bytes, offset: offset) else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "label") + } + offset = newOffset + } + + // Null terminator + guard let newOffset = buffer.copyIn(as: UInt8.self, value: 0, offset: offset) else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "terminator") + } + + guard newOffset == startOffset + size else { + throw DNSBindError.unexpectedOffset(type: "DNSName", expected: startOffset + size, actual: newOffset) + } + return newOffset + } + + /// Deserialize a name from the buffer at the given offset. + /// + /// - Parameters: + /// - buffer: The buffer to read from. + /// - offset: The offset to start reading. + /// - messageStart: The start of the DNS message (for compression pointer resolution). + /// - Returns: The new offset after reading. + public mutating func bindBuffer( + _ buffer: inout [UInt8], + offset: Int, + messageStart: Int = 0 + ) throws -> Int { + var offset = offset + var collectedLabels: [String] = [] + var jumped = false + var returnOffset = offset + var pointerHops = 0 + + while true { + guard offset < buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "name") + } + + let length = buffer[offset] + + // Check for compression pointer (top 2 bits set) + if (length & 0xC0) == 0xC0 { + guard offset + 1 < buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "pointer") + } + + pointerHops += 1 + guard pointerHops <= 10 else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "pointer") + } + + if !jumped { + returnOffset = offset + 2 + } + + // Calculate pointer offset from message start + let pointer = Int(length & 0x3F) << 8 | Int(buffer[offset + 1]) + let pointerTarget = messageStart + pointer + guard pointerTarget >= 0 && pointerTarget < offset && pointerTarget < buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "pointer") + } + offset = pointerTarget + jumped = true + continue + } + + offset += 1 + + // Null terminator - end of name + if length == 0 { + break + } + + guard offset + Int(length) <= buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "label") + } + + let labelBytes = Array(buffer[offset..> 11) & 0x0F)) else { + throw DNSBindError.unsupportedValue(type: "Message", field: "opcode") + } + self.operationCode = opCode + self.authoritativeAnswer = (flags & 0x0400) != 0 + self.truncation = (flags & 0x0200) != 0 + self.recursionDesired = (flags & 0x0100) != 0 + self.recursionAvailable = (flags & 0x0080) != 0 + guard let returnCode = ReturnCode(rawValue: UInt8(flags & 0x000F)) else { + throw DNSBindError.unsupportedValue(type: "Message", field: "rcode") + } + self.returnCode = returnCode + + // Read counts + guard let (newOffset, rawQdCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "qdcount") + } + let qdCount = UInt16(bigEndian: rawQdCount) + offset = newOffset + + guard let (newOffset, rawAnCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "ancount") + } + let anCount = UInt16(bigEndian: rawAnCount) + offset = newOffset + + guard let (newOffset, rawNsCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "nscount") + } + // nsCount not used for now, but we need to read past it + _ = UInt16(bigEndian: rawNsCount) + offset = newOffset + + guard let (newOffset, rawArCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "arcount") + } + // arCount not used for now, but we need to read past it + _ = UInt16(bigEndian: rawArCount) + offset = newOffset + + // Read questions + self.questions = [] + for _ in 0.. Data { + // Calculate exact buffer size. + var bufferSize = Self.headerSize + for question in questions { + // name + type + class + let n = question.name.hasSuffix(".") ? String(question.name.dropLast()) : question.name + bufferSize += (try DNSName(labels: n.isEmpty ? [] : n.split(separator: ".", omittingEmptySubsequences: false).map(String.init))).size + 4 + } + for answer in answers { + // name + type + class + ttl + rdlen + rdata + let n = answer.name.hasSuffix(".") ? String(answer.name.dropLast()) : answer.name + let rdataSize = answer.type == .host ? 4 : 16 + bufferSize += (try DNSName(labels: n.isEmpty ? [] : n.split(separator: ".", omittingEmptySubsequences: false).map(String.init))).size + 10 + rdataSize + } + + var buffer = [UInt8](repeating: 0, count: bufferSize) + var offset = 0 + + // Write ID + guard let newOffset = buffer.copyIn(as: UInt16.self, value: id.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "id") + } + offset = newOffset + + // Build and write flags + var flags: UInt16 = 0 + flags |= type == .response ? 0x8000 : 0 + flags |= UInt16(operationCode.rawValue) << 11 + flags |= authoritativeAnswer ? 0x0400 : 0 + flags |= truncation ? 0x0200 : 0 + flags |= recursionDesired ? 0x0100 : 0 + flags |= recursionAvailable ? 0x0080 : 0 + flags |= UInt16(returnCode.rawValue) + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: flags.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "flags") + } + offset = newOffset + + // Write counts + guard questions.count <= UInt16.max else { + throw DNSBindError.marshalFailure(type: "Message", field: "qdcount") + } + guard answers.count <= UInt16.max else { + throw DNSBindError.marshalFailure(type: "Message", field: "ancount") + } + guard authorities.count <= UInt16.max else { + throw DNSBindError.marshalFailure(type: "Message", field: "nscount") + } + guard additional.count <= UInt16.max else { + throw DNSBindError.marshalFailure(type: "Message", field: "arcount") + } + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(questions.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "qdcount") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(answers.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "ancount") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(authorities.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "nscount") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(additional.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "arcount") + } + offset = newOffset + + // Write questions + for question in questions { + offset = try question.appendBuffer(&buffer, offset: offset) + } + + // Write answers + for answer in answers { + offset = try answer.appendBuffer(&buffer, offset: offset) + } + + // Write authorities + for authority in authorities { + offset = try authority.appendBuffer(&buffer, offset: offset) + } + + // Write additional + for record in additional { + offset = try record.appendBuffer(&buffer, offset: offset) + } + + guard offset == bufferSize else { + throw DNSBindError.unexpectedOffset(type: "Message", expected: bufferSize, actual: offset) + } + return Data(buffer[0.. Int { + let startOffset = offset + var offset = offset + + // Write name + let normalized = name.hasSuffix(".") ? String(name.dropLast()) : name + let dnsName = try DNSName(labels: normalized.isEmpty ? [] : normalized.split(separator: ".", omittingEmptySubsequences: false).map(String.init)) + offset = try dnsName.appendBuffer(&buffer, offset: offset) + + // Write type (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: type.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Question", field: "type") + } + offset = newOffset + + // Write class (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: recordClass.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Question", field: "class") + } + + let expectedOffset = startOffset + dnsName.size + 4 + guard newOffset == expectedOffset else { + throw DNSBindError.unexpectedOffset(type: "Question", expected: expectedOffset, actual: newOffset) + } + return newOffset + } + + /// Deserialize a question from the buffer. + public mutating func bindBuffer(_ buffer: inout [UInt8], offset: Int, messageStart: Int = 0) throws -> Int { + var offset = offset + + // Read name + var dnsName = DNSName() + offset = try dnsName.bindBuffer(&buffer, offset: offset, messageStart: messageStart) + self.name = dnsName.description + + // Read type (big-endian) + guard let (newOffset, rawType) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Question", field: "type") + } + guard let qtype = ResourceRecordType(rawValue: UInt16(bigEndian: rawType)) else { + throw DNSBindError.unsupportedValue(type: "Question", field: "type") + } + self.type = qtype + offset = newOffset + + // Read class (big-endian) + guard let (newOffset, rawClass) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Question", field: "class") + } + guard let qclass = ResourceRecordClass(rawValue: UInt16(bigEndian: rawClass)) else { + throw DNSBindError.unsupportedValue(type: "Question", field: "class") + } + self.recordClass = qclass + + return newOffset + } +} diff --git a/Sources/DNSServer/Records/ResourceRecord.swift b/Sources/DNSServer/Records/ResourceRecord.swift new file mode 100644 index 000000000..37df99a6c --- /dev/null +++ b/Sources/DNSServer/Records/ResourceRecord.swift @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation + +/// Protocol for DNS resource records. +public protocol ResourceRecord: Sendable { + /// The domain name this record applies to. + var name: String { get } + + /// The record type. + var type: ResourceRecordType { get } + + /// The record class. + var recordClass: ResourceRecordClass { get } + + /// Time to live in seconds. + var ttl: UInt32 { get } + + /// Serialize this record into the buffer. + func appendBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int +} + +/// A host record (A or AAAA) containing an IP address. +public struct HostRecord: ResourceRecord { + public let name: String + public let type: ResourceRecordType + public let recordClass: ResourceRecordClass + public let ttl: UInt32 + public let ip: T + + public init( + name: String, + ttl: UInt32 = 300, + ip: T, + recordClass: ResourceRecordClass = .internet + ) { + self.name = name + self.type = T.recordType + self.recordClass = recordClass + self.ttl = ttl + self.ip = ip + } + + public func appendBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int { + let startOffset = offset + var offset = offset + + // Write name + let normalized = name.hasSuffix(".") ? String(name.dropLast()) : name + let dnsName = try DNSName(labels: normalized.isEmpty ? [] : normalized.split(separator: ".", omittingEmptySubsequences: false).map(String.init)) + offset = try dnsName.appendBuffer(&buffer, offset: offset) + + // Write type (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: type.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "type") + } + offset = newOffset + + // Write class (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: recordClass.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "class") + } + offset = newOffset + + // Write TTL (big-endian) + guard let newOffset = buffer.copyIn(as: UInt32.self, value: ttl.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "ttl") + } + offset = newOffset + + // Write rdlength (big-endian) + let rdlength = UInt16(T.size) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: rdlength.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "rdlength") + } + offset = newOffset + + // Write IP address bytes + guard let newOffset = buffer.copyIn(buffer: ip.bytes, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "rdata") + } + + let expectedOffset = startOffset + dnsName.size + 10 + T.size + guard newOffset == expectedOffset else { + throw DNSBindError.unexpectedOffset(type: "HostRecord", expected: expectedOffset, actual: newOffset) + } + return newOffset + } +} diff --git a/Sources/DNSServer/Records/UInt8+Binding.swift b/Sources/DNSServer/Records/UInt8+Binding.swift new file mode 100644 index 000000000..cd9b91b1a --- /dev/null +++ b/Sources/DNSServer/Records/UInt8+Binding.swift @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation + +// TODO: This copies some of the Bindable code from Containerization, +// but we can't use Bindable as it presumes a fixed length record. +// We can look at refining this later to see if we can use some common +// bit fiddling code everywhere. + +extension [UInt8] { + /// Copy a value into the buffer at the given offset. + /// - Returns: The new offset after writing, or nil if the buffer is too small. + package mutating func copyIn(as type: T.Type, value: T, offset: Int = 0) -> Int? { + let size = MemoryLayout.size + guard self.count >= size + offset else { + return nil + } + return self.withUnsafeMutableBytes { + $0.baseAddress?.advanced(by: offset).assumingMemoryBound(to: T.self).pointee = value + return offset + size + } + } + + /// Copy a value out of the buffer at the given offset. + /// - Returns: A tuple of (new offset, value), or nil if the buffer is too small. + package func copyOut(as type: T.Type, offset: Int = 0) -> (Int, T)? { + let size = MemoryLayout.size + guard self.count >= size + offset else { + return nil + } + return self.withUnsafeBytes { + guard let value = $0.baseAddress?.advanced(by: offset).assumingMemoryBound(to: T.self).pointee else { + return nil + } + return (offset + size, value) + } + } + + /// Copy a byte array into the buffer at the given offset. + /// - Returns: The new offset after writing, or nil if the buffer is too small. + package mutating func copyIn(buffer: [UInt8], offset: Int = 0) -> Int? { + guard offset + buffer.count <= self.count else { + return nil + } + self[offset.. Int? { + guard offset + buffer.count <= self.count else { + return nil + } + buffer[0.. Message? { - let question = query.questions[0] + guard let question = query.questions.first else { + return nil + } let record: ResourceRecord? switch question.type { case ResourceRecordType.host: @@ -50,28 +52,11 @@ struct ContainerDNSHandler: DNSHandler { ) } record = result.record - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) @@ -95,11 +80,11 @@ struct ContainerDNSHandler: DNSHandler { return nil } let ipv4 = ipAllocation.ipv4Address.address.description - guard let ip = IPv4(ipv4) else { + guard let ip = try? IPv4Address(ipv4) else { throw DNSResolverError.serverError("failed to parse IP address: \(ipv4)") } - return HostRecord(name: question.name, ttl: ttl, ip: ip) + return HostRecord(name: question.name, ttl: ttl, ip: ip) } private func answerHost6(question: Question) async throws -> (record: ResourceRecord?, hostnameExists: Bool) { @@ -110,10 +95,10 @@ struct ContainerDNSHandler: DNSHandler { return (nil, true) } let ipv6 = ipv6Address.address.description - guard let ip = IPv6(ipv6) else { + guard let ip = try? IPv6Address(ipv6) else { throw DNSResolverError.serverError("failed to parse IPv6 address: \(ipv6)") } - return (HostRecord(name: question.name, ttl: ttl, ip: ip), true) + return (HostRecord(name: question.name, ttl: ttl, ip: ip), true) } } diff --git a/Sources/Helpers/APIServer/LocalhostDNSHandler.swift b/Sources/Helpers/APIServer/LocalhostDNSHandler.swift index cf87badb4..4dfa72166 100644 --- a/Sources/Helpers/APIServer/LocalhostDNSHandler.swift +++ b/Sources/Helpers/APIServer/LocalhostDNSHandler.swift @@ -18,7 +18,7 @@ import ContainerAPIClient import ContainerOS import ContainerPersistence import ContainerizationError -import DNS +import ContainerizationExtras import DNSServer import Foundation import Logging @@ -28,44 +28,53 @@ actor LocalhostDNSHandler: DNSHandler { private let ttl: UInt32 private let watcher: DirectoryWatcher - private let dns: Mutex<[String: IPv4]> + private var dns: [DNSName: IPv4Address] public init(resolversURL: URL = HostDNSResolver.defaultConfigPath, ttl: UInt32 = 5, log: Logger) { self.ttl = ttl self.watcher = DirectoryWatcher(directoryURL: resolversURL, log: log) - self.dns = Mutex([:]) + self.dns = [DNSName: IPv4Address]() } public func monitorResolvers() async { - await self.watcher.startWatching { fileURLs in - var dns: [String: String] = [:] + await self.watcher.startWatching { [weak self] fileURLs in + var dns: [DNSName: IPv4Address] = [:] let regex = try Regex(HostDNSResolver.localhostOptionsRegex) for file in fileURLs.filter({ $0.lastPathComponent.starts(with: HostDNSResolver.containerizationPrefix) }) { let content = try String(contentsOf: file, encoding: .utf8) if let match = content.firstMatch(of: regex), - let ipv4 = (match[1].substring.map { String($0) }) + let ipv4 = (match[1].substring.flatMap { try? IPv4Address(String($0)) }) { let name = String(file.lastPathComponent.dropFirst(HostDNSResolver.containerizationPrefix.count)) - dns[name + "."] = ipv4 + guard let dnsName = try? DNSName(name) else { + continue + } + dns[dnsName] = ipv4 } } - self.dns.withLock { $0 = dns.compactMapValues { IPv4($0) } } + Task { await self?.updateDNS(dns) } } } - nonisolated public func answer(query: Message) async throws -> Message? { - let question = query.questions[0] + public func answer(query: Message) async throws -> Message? { + guard let question = query.questions.first else { + return nil + } + let n = question.name.hasSuffix(".") ? String(question.name.dropLast()) : question.name + let key = try DNSName(labels: n.isEmpty ? [] : n.split(separator: ".", omittingEmptySubsequences: false).map(String.init)) var record: ResourceRecord? switch question.type { case ResourceRecordType.host: - let dns = dns.withLock { $0 } - if let ip = dns[question.name] { - record = HostRecord(name: question.name, ttl: ttl, ip: ip) + if let ip = dns[key] { + record = HostRecord(name: question.name, ttl: ttl, ip: ip) } case ResourceRecordType.host6: + guard dns[key] != nil else { + return nil + } return Message( id: query.id, type: .response, @@ -73,28 +82,11 @@ actor LocalhostDNSHandler: DNSHandler { questions: query.questions, answers: [] ) - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) @@ -112,4 +104,8 @@ actor LocalhostDNSHandler: DNSHandler { answers: [record] ) } + + private func updateDNS(_ dns: [DNSName: IPv4Address]) { + self.dns = dns + } } diff --git a/Sources/Services/ContainerAPIService/Client/HostDNSResolver.swift b/Sources/Services/ContainerAPIService/Client/HostDNSResolver.swift index c5974db56..5e5fb3b74 100644 --- a/Sources/Services/ContainerAPIService/Client/HostDNSResolver.swift +++ b/Sources/Services/ContainerAPIService/Client/HostDNSResolver.swift @@ -51,10 +51,10 @@ public struct HostDNSResolver { let dnsPort = localhost == nil ? "2053" : "1053" let options = - localhost == nil - ? "" - : HostDNSResolver.localhostOptionsRegex.replacingOccurrences( - of: #"\((.*?)\)"#, with: localhost!.description, options: .regularExpression) + localhost.map { + HostDNSResolver.localhostOptionsRegex.replacingOccurrences( + of: #"\((.*?)\)"#, with: $0.description, options: .regularExpression) + } ?? "" let resolverText = """ domain \(name) search \(name) diff --git a/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift b/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift index 7966f77d2..4770078e8 100644 --- a/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift +++ b/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift @@ -339,6 +339,8 @@ public actor NetworksService { } /// Perform a hostname lookup on all networks. + /// + /// - Parameter hostname: A canonical DNS hostname with a trailing dot (e.g. `"example.com."`). public func lookup(hostname: String) async throws -> Attachment? { try await self.stateLock.withLock { _ in for state in await self.serviceStates.values { diff --git a/Tests/DNSServerTests/CompositeResolverTest.swift b/Tests/DNSServerTests/CompositeResolverTest.swift index 227c6b997..df9a4fe1b 100644 --- a/Tests/DNSServerTests/CompositeResolverTest.swift +++ b/Tests/DNSServerTests/CompositeResolverTest.swift @@ -14,7 +14,7 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer @@ -29,35 +29,35 @@ struct CompositeResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let fooResponse = try await resolver.answer(query: fooQuery) #expect(.noError == fooResponse?.returnCode) #expect(1 == fooResponse?.id) #expect(1 == fooResponse?.answers.count) - let fooAnswer = fooResponse?.answers[0] as? HostRecord - #expect(IPv4("1.2.3.4") == fooAnswer?.ip) + let fooAnswer = fooResponse?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == fooAnswer?.ip) let barQuery = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host) + Question(name: "bar.", type: .host) ]) let barResponse = try await resolver.answer(query: barQuery) #expect(.noError == barResponse?.returnCode) #expect(1 == barResponse?.id) #expect(1 == barResponse?.answers.count) - let barAnswer = barResponse?.answers[0] as? HostRecord - #expect(IPv4("5.6.7.8") == barAnswer?.ip) + let barAnswer = barResponse?.answers[0] as? HostRecord + #expect(try IPv4Address("5.6.7.8") == barAnswer?.ip) let otherQuery = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "other", type: .host) + Question(name: "other.", type: .host) ]) let otherResponse = try await resolver.answer(query: otherQuery) diff --git a/Tests/DNSServerTests/HostTableResolverTest.swift b/Tests/DNSServerTests/HostTableResolverTest.swift index 1a7aff2cc..96f980d17 100644 --- a/Tests/DNSServerTests/HostTableResolverTest.swift +++ b/Tests/DNSServerTests/HostTableResolverTest.swift @@ -14,23 +14,32 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer struct HostTableResolverTest { + @Test func testEmptyQuestionsReturnsNil() async throws { + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) + + let query = Message(id: UInt16(1), type: .query, questions: []) + + let response = try await handler.answer(query: query) + + #expect(nil == response) + } + @Test func testUnsupportedQuestionType() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } - let handler = HostTableResolver(hosts4: ["foo": ip]) + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .mailExchange) + Question(name: "foo.", type: .mailExchange) ]) let response = try await handler.answer(query: query) @@ -43,16 +52,14 @@ struct HostTableResolverTest { } @Test func testAAAAQueryReturnsNoDataWhenARecordExists() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } - let handler = HostTableResolver(hosts4: ["foo": ip]) + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host6) + Question(name: "foo.", type: .host6) ]) let response = try await handler.answer(query: query) @@ -67,16 +74,14 @@ struct HostTableResolverTest { } @Test func testAAAAQueryReturnsNilWhenHostDoesNotExist() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } - let handler = HostTableResolver(hosts4: ["foo": ip]) + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host6) + Question(name: "bar.", type: .host6) ]) let response = try await handler.answer(query: query) @@ -86,16 +91,14 @@ struct HostTableResolverTest { } @Test func testHostNotPresent() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } - let handler = HostTableResolver(hosts4: ["foo": ip]) + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host) + Question(name: "bar.", type: .host) ]) let response = try await handler.answer(query: query) @@ -104,16 +107,62 @@ struct HostTableResolverTest { } @Test func testHostPresent() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } - let handler = HostTableResolver(hosts4: ["foo": ip]) + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) + + let query = Message( + id: UInt16(1), + type: .query, + questions: [ + Question(name: "foo.", type: .host) + ]) + + let response = try await handler.answer(query: query) + + #expect(.noError == response?.returnCode) + #expect(1 == response?.id) + #expect(.response == response?.type) + #expect(1 == response?.questions.count) + #expect("foo." == response?.questions[0].name) + #expect(.host == response?.questions[0].type) + #expect(1 == response?.answers.count) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) + } + + @Test func testHostPresentUppercaseTable() async throws { + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["FOO.": ip]) + + let query = Message( + id: UInt16(1), + type: .query, + questions: [ + Question(name: "foo.", type: .host) + ]) + + let response = try await handler.answer(query: query) + + #expect(.noError == response?.returnCode) + #expect(1 == response?.id) + #expect(.response == response?.type) + #expect(1 == response?.questions.count) + #expect("foo." == response?.questions[0].name) + #expect(.host == response?.questions[0].type) + #expect(1 == response?.answers.count) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) + } + + @Test func testHostPresentUppercaseQuestion() async throws { + let ip = try IPv4Address("1.2.3.4") + let handler = try HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host) + Question(name: "FOO.", type: .host) ]) let response = try await handler.answer(query: query) @@ -122,10 +171,10 @@ struct HostTableResolverTest { #expect(1 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("FOO." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(1 == response?.answers.count) - let answer = response?.answers[0] as? HostRecord - #expect(IPv4("1.2.3.4") == answer?.ip) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) } } diff --git a/Tests/DNSServerTests/MockHandlers.swift b/Tests/DNSServerTests/MockHandlers.swift index 6496c7e04..b0a4740cb 100644 --- a/Tests/DNSServerTests/MockHandlers.swift +++ b/Tests/DNSServerTests/MockHandlers.swift @@ -14,23 +14,21 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer struct FooHandler: DNSHandler { public func answer(query: Message) async throws -> Message? { - if query.questions[0].name == "foo" { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + if query.questions[0].name == "foo." { + let ip = try IPv4Address("1.2.3.4") return Message( id: query.id, type: .response, returnCode: .noError, questions: query.questions, - answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] + answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] ) } return nil @@ -40,16 +38,14 @@ struct FooHandler: DNSHandler { struct BarHandler: DNSHandler { public func answer(query: Message) async throws -> Message? { let question = query.questions[0] - if question.name == "foo" || question.name == "bar" { - guard let ip = IPv4("5.6.7.8") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + if question.name == "foo." || question.name == "bar." { + let ip = try IPv4Address("5.6.7.8") return Message( id: query.id, type: .response, returnCode: .noError, questions: query.questions, - answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] + answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] ) } return nil diff --git a/Tests/DNSServerTests/NxDomainResolverTest.swift b/Tests/DNSServerTests/NxDomainResolverTest.swift index db592e56d..27264c158 100644 --- a/Tests/DNSServerTests/NxDomainResolverTest.swift +++ b/Tests/DNSServerTests/NxDomainResolverTest.swift @@ -14,7 +14,6 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS import Testing @testable import DNSServer @@ -27,7 +26,7 @@ struct NxDomainResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host6) + Question(name: "foo.", type: .host6) ]) let response = try await handler.answer(query: query) @@ -46,7 +45,7 @@ struct NxDomainResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host) + Question(name: "bar.", type: .host) ]) let response = try await handler.answer(query: query) diff --git a/Tests/DNSServerTests/RecordsTests.swift b/Tests/DNSServerTests/RecordsTests.swift new file mode 100644 index 000000000..71e7726da --- /dev/null +++ b/Tests/DNSServerTests/RecordsTests.swift @@ -0,0 +1,673 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import ContainerizationExtras +import Foundation +import Testing + +@testable import DNSServer + +@Suite("DNS Records Tests") +struct RecordsTests { + + // MARK: - DNSName Tests + + @Suite("DNSName") + struct DNSNameTests { + @Test("Create from string") + func createFromString() throws { + let name = try DNSName("example.com") + #expect(name.labels == ["example", "com"]) + } + + @Test("Create from string with trailing dot") + func createFromStringTrailingDot() throws { + let name = try DNSName("example.com.") + #expect(name.labels == ["example", "com"]) + } + + @Test("Description includes trailing dot") + func descriptionTrailingDot() throws { + let name = try DNSName("example.com") + #expect(name.description == "example.com.") + } + + @Test("Root domain") + func rootDomain() throws { + let name = try DNSName("") + #expect(name.labels == []) + #expect(name.description == ".") + } + + @Test("Size calculation") + func sizeCalculation() throws { + let name = try DNSName("example.com") + // [7]example[3]com[0] = 1 + 7 + 1 + 3 + 1 = 13 + #expect(name.size == 13) + } + + @Test("Serialize and deserialize") + func serializeDeserialize() throws { + let original = try DNSName("test.example.com") + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try original.appendBuffer(&buffer, offset: 0) + + var parsed = DNSName() + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + // [4]test[7]example[3]com[0] = 5+8+4+1 = 18 + #expect(endOffset == 18) + #expect(readOffset == endOffset) + #expect(parsed.labels == original.labels) + } + + @Test("Serialize subdomain") + func serializeSubdomain() throws { + let name = try DNSName("a.b.c.d.example.com") + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try name.appendBuffer(&buffer, offset: 0) + + var parsed = DNSName() + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + // [1]a[1]b[1]c[1]d[7]example[3]com[0] = 2+2+2+2+8+4+1 = 21 + #expect(endOffset == 21) + #expect(readOffset == endOffset) + #expect(parsed.labels == ["a", "b", "c", "d", "example", "com"]) + } + + @Test("Reject label too long") + func rejectLabelTooLong() { + let longLabel = String(repeating: "a", count: 64) + #expect(throws: DNSBindError.self) { + _ = try DNSName(longLabel + ".com") + } + } + + @Test("Reject embedded carriage return") + func rejectEmbeddedCarriageReturn() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("foo\r.com") + } + } + + @Test("Reject embedded newline") + func rejectEmbeddedNewline() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("foo\n.com") + } + } + + @Test("Reject embedded null byte") + func rejectEmbeddedNullByte() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("foo\0.com") + } + } + + @Test("Reject empty label") + func rejectEmptyLabel() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("foo..com") + } + } + + @Test("Reject name too long") + func rejectNameTooLong() { + // 9 labels * (1 + 30) bytes + 1 null = 280 bytes > 255 + let label = String(repeating: "a", count: 30) + let name = Array(repeating: label, count: 9).joined(separator: ".") + #expect(throws: DNSBindError.self) { + _ = try DNSName(name) + } + } + + @Test("Reject leading hyphen") + func rejectLeadingHyphen() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("-foo.com") + } + } + + @Test("Reject trailing hyphen") + func rejectTrailingHyphen() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("foo-.com") + } + } + + @Test("Reject leading underscore") + func rejectLeadingUnderscore() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("_foo.com") + } + } + + @Test("Reject trailing underscore") + func rejectTrailingUnderscore() { + #expect(throws: DNSBindError.self) { + _ = try DNSName("foo_.com") + } + } + + @Test("Accept service labels via init(labels:)") + func acceptServiceLabels() throws { + let name = try DNSName(labels: ["_dns-sd", "_udp", "local"]) + #expect(name.labels == ["_dns-sd", "_udp", "local"]) + } + + @Test("Lowercase labels on init") + func lowercaseLabelsOnInit() throws { + let name = try DNSName("EXAMPLE.COM") + #expect(name.labels == ["example", "com"]) + } + + @Test("Lowercase labels on init with trailing dot") + func lowercaseLabelsOnInitTrailingDot() throws { + let name = try DNSName("Example.Com.") + #expect(name.labels == ["example", "com"]) + } + + @Test("Lowercase labels from wire format") + func lowercaseLabelsFromWire() throws { + // Wire-encode "EXAMPLE.COM" with uppercase bytes, then decode + let upper = try DNSName(labels: ["EXAMPLE", "COM"]) + var buffer = [UInt8](repeating: 0, count: 64) + let endOffset = try upper.appendBuffer(&buffer, offset: 0) + + var parsed = DNSName() + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + // [7]example[3]com[0] = 8+4+1 = 13 + #expect(endOffset == 13) + #expect(readOffset == endOffset) + #expect(parsed.labels == ["example", "com"]) + } + + @Test("Follow valid compression pointer") + func followCompressionPointer() throws { + // Build a buffer with two names: + // offset 0: "example.com." — [7]example[3]com[0] (13 bytes) + // offset 13: "test." — [4]test 0xC0 0x00 ( 7 bytes) + // The pointer 0xC0 0x00 points back to offset 0. + var buffer: [UInt8] = [ + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // [7]example + 0x03, 0x63, 0x6f, 0x6d, // [3]com + 0x00, // null terminator + 0x04, 0x74, 0x65, 0x73, 0x74, // [4]test + 0xC0, 0x00, // pointer to offset 0 + ] + + var name = DNSName() + let readOffset = try name.bindBuffer(&buffer, offset: 13) + + // Pointer bytes are at offset 18–19; returnOffset = 18 + 2 = 20 + #expect(readOffset == 20) + #expect(name.labels == ["test", "example", "com"]) + } + + @Test("Reject forward compression pointer") + func rejectForwardCompressionPointer() throws { + // Craft a packet with a forward compression pointer at offset 12 pointing to offset 20 + // Header (12 bytes) + pointer bytes + var buffer = [UInt8](repeating: 0, count: 32) + // At offset 0: compression pointer to offset 20 (forward) + buffer[0] = 0xC0 + buffer[1] = 0x14 // points to offset 20, which is > 0 + + #expect(throws: DNSBindError.self) { + var b = buffer + var name = DNSName() + _ = try name.bindBuffer(&b, offset: 0) + } + } + + @Test("Reject self-referential compression pointer") + func rejectSelfReferentialCompressionPointer() throws { + var buffer = [UInt8](repeating: 0, count: 16) + // At offset 0: compression pointer pointing back to offset 0 (same location) + buffer[0] = 0xC0 + buffer[1] = 0x00 // points to offset 0 == current offset, not prior + + #expect(throws: DNSBindError.self) { + var b = buffer + var name = DNSName() + _ = try name.bindBuffer(&b, offset: 0) + } + } + + @Test("Reject compression pointer hop limit exceeded") + func rejectCompressionPointerHopLimit() throws { + // Build a chain of backward pointers: + // offset 0: [1]a[0] — terminal name (3 bytes) + // offset 3: 0xC0 0x00 — pointer → 0 + // offset 5: 0xC0 0x03 — pointer → 3 + // ...each entry points to the one before it... + // offset 23: 0xC0 0x15 — pointer → 21 + // offset 25: 0xC0 0x17 — pointer → 23 + // + // Reading from offset 25 follows 11 hops (25→23→21→...→3→0), + // which exceeds the limit of 10. + var buffer: [UInt8] = [ + 0x01, 0x61, 0x00, // offset 0: [1]a[0] + 0xC0, 0x00, // offset 3: → 0 + 0xC0, 0x03, // offset 5: → 3 + 0xC0, 0x05, // offset 7: → 5 + 0xC0, 0x07, // offset 9: → 7 + 0xC0, 0x09, // offset 11: → 9 + 0xC0, 0x0B, // offset 13: → 11 + 0xC0, 0x0D, // offset 15: → 13 + 0xC0, 0x0F, // offset 17: → 15 + 0xC0, 0x11, // offset 19: → 17 + 0xC0, 0x13, // offset 21: → 19 + 0xC0, 0x15, // offset 23: → 21 + 0xC0, 0x17, // offset 25: → 23 + ] + + #expect(throws: DNSBindError.self) { + var name = DNSName() + _ = try name.bindBuffer(&buffer, offset: 25) + } + } + } + + // MARK: - Question Tests + + @Suite("Question") + struct QuestionTests { + @Test("Create question") + func create() { + let q = Question(name: "example.com.", type: .host, recordClass: .internet) + #expect(q.name == "example.com.") + #expect(q.type == .host) + #expect(q.recordClass == .internet) + } + + @Test("Serialize and deserialize A record question") + func serializeDeserializeA() throws { + let original = Question(name: "example.com.", type: .host, recordClass: .internet) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try original.appendBuffer(&buffer, offset: 0) + + var parsed = Question(name: "") + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + // name([7]example[3]com[0]=13) + type(2) + class(2) = 17 + #expect(endOffset == 17) + #expect(readOffset == endOffset) + #expect(parsed.type == .host) + #expect(parsed.recordClass == .internet) + } + + @Test("Serialize and deserialize AAAA record question") + func serializeDeserializeAAAA() throws { + let original = Question(name: "example.com.", type: .host6, recordClass: .internet) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try original.appendBuffer(&buffer, offset: 0) + + var parsed = Question(name: "") + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + // name([7]example[3]com[0]=13) + type(2) + class(2) = 17 + #expect(endOffset == 17) + #expect(readOffset == endOffset) + #expect(parsed.type == .host6) + } + } + + // MARK: - HostRecord Tests + + @Suite("HostRecord") + struct HostRecordTests { + @Test("Create A record") + func createARecord() throws { + let ip = try IPv4Address("192.168.1.1") + let record = HostRecord(name: "example.com.", ttl: 300, ip: ip) + + #expect(record.name == "example.com.") + #expect(record.type == .host) + #expect(record.ttl == 300) + #expect(record.ip == ip) + } + + @Test("Create AAAA record") + func createAAAARecord() throws { + let ip = try IPv6Address("::1") + let record = HostRecord(name: "example.com.", ttl: 600, ip: ip) + + #expect(record.name == "example.com.") + #expect(record.type == .host6) + #expect(record.ttl == 600) + } + + @Test("Serialize A record") + func serializeARecord() throws { + let ip = try IPv4Address("10.0.0.1") + let record = HostRecord(name: "test.com.", ttl: 300, ip: ip) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try record.appendBuffer(&buffer, offset: 0) + + // name([4]test[3]com[0]=10) + type(2) + class(2) + ttl(4) + rdlen(2) + rdata(4) = 24 + #expect(endOffset == 24) + + // Verify IP bytes at the end + #expect(buffer[endOffset - 4] == 10) + #expect(buffer[endOffset - 3] == 0) + #expect(buffer[endOffset - 2] == 0) + #expect(buffer[endOffset - 1] == 1) + } + + @Test("Serialize AAAA record") + func serializeAAAARecord() throws { + let ip = try IPv6Address("::1") + let record = HostRecord(name: "test.com.", ttl: 300, ip: ip) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try record.appendBuffer(&buffer, offset: 0) + + // name([4]test[3]com[0]=10) + type(2) + class(2) + ttl(4) + rdlen(2) + rdata(16) = 36 + #expect(endOffset == 36) + #expect(buffer[endOffset - 1] == 1) + } + } + + // MARK: - Message Tests + + @Suite("Message") + struct MessageTests { + @Test("Create query message") + func createQuery() { + let msg = Message( + id: 0x1234, + type: .query, + questions: [Question(name: "example.com.", type: .host)] + ) + + #expect(msg.id == 0x1234) + #expect(msg.type == .query) + #expect(msg.questions.count == 1) + } + + @Test("Create response message") + func createResponse() throws { + let ip = try IPv4Address("192.168.1.1") + let msg = Message( + id: 0x1234, + type: .response, + returnCode: .noError, + questions: [Question(name: "example.com.", type: .host)], + answers: [HostRecord(name: "example.com.", ttl: 300, ip: ip)] + ) + + #expect(msg.type == .response) + #expect(msg.returnCode == .noError) + #expect(msg.answers.count == 1) + } + + @Test("Serialize and deserialize query") + func serializeDeserializeQuery() throws { + let original = Message( + id: 0xABCD, + type: .query, + recursionDesired: true, + questions: [Question(name: "example.com.", type: .host)] + ) + + let data = try original.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.id == 0xABCD) + #expect(parsed.type == .query) + #expect(parsed.recursionDesired == true) + #expect(parsed.questions.count == 1) + #expect(parsed.questions[0].type == .host) + } + + @Test("Serialize response with answer") + func serializeResponse() throws { + let ip = try IPv4Address("10.0.0.1") + let msg = Message( + id: 0x1234, + type: .response, + authoritativeAnswer: true, + returnCode: .noError, + questions: [Question(name: "test.com.", type: .host)], + answers: [HostRecord(name: "test.com.", ttl: 300, ip: ip)] + ) + + let data = try msg.serialize() + + // Verify we can at least parse the header back + let parsed = try Message(deserialize: data) + #expect(parsed.id == 0x1234) + #expect(parsed.type == .response) + #expect(parsed.authoritativeAnswer == true) + #expect(parsed.returnCode == .noError) + } + + @Test("Serialize NXDOMAIN response") + func serializeNxdomain() throws { + let msg = Message( + id: 0x1234, + type: .response, + returnCode: .nonExistentDomain, + questions: [Question(name: "unknown.com.", type: .host)], + answers: [] + ) + + let data = try msg.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.returnCode == .nonExistentDomain) + #expect(parsed.answers.count == 0) + } + + @Test("Serialize NODATA response (empty answers with noError)") + func serializeNodata() throws { + let msg = Message( + id: 0x1234, + type: .response, + returnCode: .noError, + questions: [Question(name: "example.com.", type: .host6)], + answers: [] + ) + + let data = try msg.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.returnCode == .noError) + #expect(parsed.answers.count == 0) + } + + @Test("Multiple questions") + func multipleQuestions() throws { + let msg = Message( + id: 0x1234, + type: .query, + questions: [ + Question(name: "a.com.", type: .host), + Question(name: "b.com.", type: .host6), + ] + ) + + let data = try msg.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.questions.count == 2) + #expect(parsed.questions[0].type == .host) + #expect(parsed.questions[1].type == .host6) + } + + @Test("Reject too many questions") + func rejectTooManyQuestions() { + let questions = Array(repeating: Question(name: "a.com.", type: .host), count: Int(UInt16.max) + 1) + let msg = Message(id: 0, type: .query, questions: questions) + #expect(throws: DNSBindError.self) { + _ = try msg.serialize() + } + } + + @Test("Reject too many answers") + func rejectTooManyAnswers() throws { + let ip = try IPv4Address("1.2.3.4") + let answers = Array(repeating: HostRecord(name: "a.com.", ttl: 0, ip: ip), count: Int(UInt16.max) + 1) + let msg = Message(id: 0, type: .response, answers: answers) + #expect(throws: DNSBindError.self) { + _ = try msg.serialize() + } + } + } + + // MARK: - Wire Format Tests + + @Suite("Wire Format") + struct WireFormatTests { + @Test("Parse real DNS query bytes") + func parseRealQuery() throws { + // A minimal DNS query for "example.com" A record + // Header: ID=0x1234, QR=0, OPCODE=0, RD=1, QDCOUNT=1 + let queryBytes: [UInt8] = [ + 0x12, 0x34, // ID + 0x01, 0x00, // Flags: RD=1 + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, // ANCOUNT=0 + 0x00, 0x00, // NSCOUNT=0 + 0x00, 0x00, // ARCOUNT=0 + // Question: example.com A IN + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // "example" + 0x03, 0x63, 0x6f, 0x6d, // "com" + 0x00, // null terminator + 0x00, 0x01, // QTYPE=A + 0x00, 0x01, // QCLASS=IN + ] + + let msg = try Message(deserialize: Data(queryBytes)) + + #expect(msg.id == 0x1234) + #expect(msg.type == .query) + #expect(msg.recursionDesired == true) + #expect(msg.questions.count == 1) + #expect(msg.questions[0].type == .host) + #expect(msg.questions[0].recordClass == .internet) + } + + @Test("Roundtrip preserves data") + func roundtrip() throws { + let ip = try IPv4Address("1.2.3.4") + let original = Message( + id: 0xBEEF, + type: .response, + operationCode: .query, + authoritativeAnswer: true, + truncation: false, + recursionDesired: true, + recursionAvailable: true, + returnCode: .noError, + questions: [Question(name: "test.example.com.", type: .host)], + answers: [HostRecord(name: "test.example.com.", ttl: 3600, ip: ip)] + ) + + let data = try original.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.id == original.id) + #expect(parsed.type == original.type) + #expect(parsed.authoritativeAnswer == original.authoritativeAnswer) + #expect(parsed.truncation == original.truncation) + #expect(parsed.recursionDesired == original.recursionDesired) + #expect(parsed.recursionAvailable == original.recursionAvailable) + #expect(parsed.returnCode == original.returnCode) + #expect(parsed.questions.count == original.questions.count) + } + + @Test("Reject unknown opcode") + func rejectUnknownOpcode() { + // Opcode occupies bits 14–11 of the flags word. Value 3 is reserved. + // Flags: 0x18 0x00 = QR=0, OPCODE=3, all other bits clear. + let bytes: [UInt8] = [ + 0x00, 0x01, // ID + 0x18, 0x00, // Flags: OPCODE=3 (reserved) + 0x00, 0x00, // QDCOUNT=0 + 0x00, 0x00, // ANCOUNT=0 + 0x00, 0x00, // NSCOUNT=0 + 0x00, 0x00, // ARCOUNT=0 + ] + #expect(throws: DNSBindError.self) { + _ = try Message(deserialize: Data(bytes)) + } + } + + @Test("Reject unknown RCODE") + func rejectUnknownRcode() { + // RCODE occupies bits 3–0 of the flags word. Value 12 is reserved. + // Flags: 0x00 0x0C = QR=0, OPCODE=0, RCODE=12. + let bytes: [UInt8] = [ + 0x00, 0x01, // ID + 0x00, 0x0C, // Flags: RCODE=12 (reserved) + 0x00, 0x00, // QDCOUNT=0 + 0x00, 0x00, // ANCOUNT=0 + 0x00, 0x00, // NSCOUNT=0 + 0x00, 0x00, // ARCOUNT=0 + ] + #expect(throws: DNSBindError.self) { + _ = try Message(deserialize: Data(bytes)) + } + } + + @Test("Reject unknown query type") + func rejectUnknownQueryType() { + // Type 54 is unassigned in the IANA DNS parameters registry. + let bytes: [UInt8] = [ + 0x00, 0x01, // ID + 0x00, 0x00, // Flags: standard query + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, // ANCOUNT=0 + 0x00, 0x00, // NSCOUNT=0 + 0x00, 0x00, // ARCOUNT=0 + 0x01, 0x61, 0x00, // name: [1]a[0] + 0x00, 0x36, // QTYPE=54 (unassigned) + 0x00, 0x01, // QCLASS=IN + ] + #expect(throws: DNSBindError.self) { + _ = try Message(deserialize: Data(bytes)) + } + } + + @Test("Reject unknown record class") + func rejectUnknownRecordClass() { + // Class 2 is unassigned in the IANA DNS parameters registry. + let bytes: [UInt8] = [ + 0x00, 0x01, // ID + 0x00, 0x00, // Flags: standard query + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, // ANCOUNT=0 + 0x00, 0x00, // NSCOUNT=0 + 0x00, 0x00, // ARCOUNT=0 + 0x01, 0x61, 0x00, // name: [1]a[0] + 0x00, 0x01, // QTYPE=A + 0x00, 0x02, // QCLASS=2 (unassigned) + ] + #expect(throws: DNSBindError.self) { + _ = try Message(deserialize: Data(bytes)) + } + } + } +} diff --git a/Tests/DNSServerTests/StandardQueryValidatorTest.swift b/Tests/DNSServerTests/StandardQueryValidatorTest.swift index 185b00b44..868344fb9 100644 --- a/Tests/DNSServerTests/StandardQueryValidatorTest.swift +++ b/Tests/DNSServerTests/StandardQueryValidatorTest.swift @@ -14,7 +14,7 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer @@ -28,7 +28,7 @@ struct StandardQueryValidatorTest { id: UInt16(1), type: .response, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -37,7 +37,7 @@ struct StandardQueryValidatorTest { #expect(1 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(0 == response?.answers.count) } @@ -51,7 +51,7 @@ struct StandardQueryValidatorTest { type: .query, operationCode: .notify, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -60,11 +60,25 @@ struct StandardQueryValidatorTest { #expect(2 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(0 == response?.answers.count) } + @Test func testRejectNoQuestions() async throws { + let fooHandler = FooHandler() + let handler = StandardQueryValidator(handler: fooHandler) + + let query = Message(id: UInt16(3), type: .query, questions: []) + + let response = try await handler.answer(query: query) + + #expect(.formatError == response?.returnCode) + #expect(3 == response?.id) + #expect(.response == response?.type) + #expect(0 == response?.answers.count) + } + @Test func testRejectMultipleQuestions() async throws { let fooHandler = FooHandler() let handler = StandardQueryValidator(handler: fooHandler) @@ -73,8 +87,8 @@ struct StandardQueryValidatorTest { id: UInt16(2), type: .query, questions: [ - Question(name: "foo", type: .host), - Question(name: "bar", type: .host), + Question(name: "foo.", type: .host), + Question(name: "bar.", type: .host), ]) let response = try await handler.answer(query: query) @@ -83,9 +97,9 @@ struct StandardQueryValidatorTest { #expect(2 == response?.id) #expect(.response == response?.type) #expect(2 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) - #expect("bar" == response?.questions[1].name) + #expect("bar." == response?.questions[1].name) #expect(.host == response?.questions[1].type) #expect(0 == response?.answers.count) } @@ -98,7 +112,7 @@ struct StandardQueryValidatorTest { id: UInt16(2), type: .query, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -107,10 +121,10 @@ struct StandardQueryValidatorTest { #expect(2 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(1 == response?.answers.count) - let answer = response?.answers[0] as? HostRecord - #expect(IPv4("1.2.3.4") == answer?.ip) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) } }