Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions Sources/SmartCodableMacros/SmartSubclassMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ import SwiftSyntaxMacros

/// A macro that automatically implements SmartCodable inheritance support
public struct SmartSubclassMacro: MemberMacro {
private enum SynthesizedMemberAccess {
case inheritedDefault
case publicVisible

var prefix: String {
switch self {
case .inheritedDefault:
return ""
case .publicVisible:
return "public "
}
}
}

public static func expansion(
of node: AttributeSyntax,
providingMembersOf declaration: some DeclGroupSyntax,
Expand Down Expand Up @@ -47,24 +61,25 @@ public struct SmartSubclassMacro: MemberMacro {

// 获取类的属性
let properties = try extractProperties(from: classDecl)
let memberAccess = synthesizedMemberAccess(for: classDecl)

var members: [DeclSyntax] = []

// 生成CodingKeys枚举
members.append(generateCodingKeysEnum(for: properties))

// 生成init(from:)方法
members.append(generateInitFromDecoder(for: properties))
members.append(generateInitFromDecoder(for: properties, access: memberAccess))

// 生成encode(to:)方法
members.append(generateEncodeToEncoder(for: properties))
members.append(generateEncodeToEncoder(for: properties, access: memberAccess))


if hasRequiredInitializer(classDecl) {
return members
} else {
// 生成required init()方法
members.append(generateRequiredInit())
members.append(generateRequiredInit(access: memberAccess))
return members
}
}
Expand Down Expand Up @@ -148,7 +163,10 @@ public struct SmartSubclassMacro: MemberMacro {
}

// 辅助方法:生成init(from:)方法
private static func generateInitFromDecoder(for properties: [ModelMemberProperty]) -> DeclSyntax {
private static func generateInitFromDecoder(
for properties: [ModelMemberProperty],
access: SynthesizedMemberAccess
) -> DeclSyntax {
let decodingStatements = properties.map { property in
let propertyName = property.accessName
let propertyType = property.type
Expand All @@ -163,7 +181,7 @@ public struct SmartSubclassMacro: MemberMacro {
}.joined(separator: "\n")

return """
required init(from decoder: Decoder) throws {
\(raw: access.prefix)required init(from decoder: Decoder) throws {
try super.init(from: decoder)

let container = try decoder.container(keyedBy: CodingKeys.self)
Expand All @@ -173,7 +191,10 @@ public struct SmartSubclassMacro: MemberMacro {
}

// 辅助方法:生成encode(to:)方法
private static func generateEncodeToEncoder(for properties: [ModelMemberProperty]) -> DeclSyntax {
private static func generateEncodeToEncoder(
for properties: [ModelMemberProperty],
access: SynthesizedMemberAccess
) -> DeclSyntax {
let encodingStatements = properties.map { property in
if property.type.hasSuffix("?") {
return "try container.encodeIfPresent(\(property.accessName), forKey: .\(property.codingKeyName))"
Expand All @@ -183,7 +204,7 @@ public struct SmartSubclassMacro: MemberMacro {
}.joined(separator: "\n")

return """
override func encode(to encoder: Encoder) throws {
\(raw: access.prefix)override func encode(to encoder: Encoder) throws {
try super.encode(to: encoder)

var container = encoder.container(keyedBy: CodingKeys.self)
Expand All @@ -206,11 +227,22 @@ public struct SmartSubclassMacro: MemberMacro {
}

// 辅助方法:生成required init()方法
private static func generateRequiredInit() -> DeclSyntax {
private static func generateRequiredInit(access: SynthesizedMemberAccess) -> DeclSyntax {
return """
required init() {
\(raw: access.prefix)required init() {
super.init()
}
"""
}

private static func synthesizedMemberAccess(for classDecl: ClassDeclSyntax) -> SynthesizedMemberAccess {
if classDecl.modifiers.contains(where: { modifier in
let name = modifier.name.text
return name == "public" || name == "open"
}) {
return .publicVisible
}

return .inheritedDefault
}
}
220 changes: 220 additions & 0 deletions Tests/SmartSubclassMacroAccessControlTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import XCTest
import SwiftSyntaxMacros
import SwiftSyntaxMacrosTestSupport
@testable import SmartCodableMacros

/// Tests for `@SmartSubclass` macro access control inference.
///
/// These tests verify that the macro correctly derives access modifiers for
/// synthesized members (`init(from:)`, `encode(to:)`, `init()`) based on the
/// visibility of the host class.
final class SmartSubclassMacroAccessControlTests: XCTestCase {
private let macros: [String: Macro.Type] = [
"SmartSubclass": SmartSubclassMacro.self
]

// MARK: - Helpers

/// Base model definition shared across all test cases.
private let baseModelDefinition = """
class BaseModel {
var name: String = ""

required init() {}
required init(from decoder: Decoder) throws {}
func encode(to encoder: Encoder) throws {}
}
"""

/// Asserts macro expansion with consistent base model context.
///
/// - Parameters:
/// - classDeclaration: The subclass declaration with `@SmartSubclass` attribute.
/// - expectedClassOutput: The expected expanded source code.
/// - file: Source file for failure reporting.
/// - line: Line number for failure reporting.
private func assertAccessControlMacroExpansion(
classDeclaration: String,
expectedClassOutput: String,
file: StaticString = #file,
line: UInt = #line
) {
let input = """
\(baseModelDefinition)

@SmartSubclass
\(classDeclaration)
"""

let expectedOutput = """
\(baseModelDefinition)
\(expectedClassOutput)
"""

assertMacroExpansion(
input,
expandedSource: expectedOutput,
macros: macros,
file: file,
line: line
)
}

// MARK: - Tests

/// Verifies that `public` class generates members with `public` modifiers.
func testPublicClassExpansionAddsPublicAccessModifiers() {
assertAccessControlMacroExpansion(
classDeclaration: """
public class PublicStudent: BaseModel {
var age: Int = 0
}
""",
expectedClassOutput: """
public class PublicStudent: BaseModel {
var age: Int = 0

enum CodingKeys: CodingKey {
case age
}

public required init(from decoder: Decoder) throws {
try super.init(from: decoder)

let container = try decoder.container(keyedBy: CodingKeys.self)
self.age = try container.decodeIfPresent(Int.self, forKey: .age) ?? self.age
}

public override func encode(to encoder: Encoder) throws {
try super.encode(to: encoder)

var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(age, forKey: .age)
}

public required init() {
super.init()
}
}
"""
)
}

/// Verifies that `open` class generates members with `public` modifiers (Phase 1).
func testOpenClassExpansionAddsPublicAccessModifiers() {
assertAccessControlMacroExpansion(
classDeclaration: """
open class OpenStudent: BaseModel {
var age: Int = 0
}
""",
expectedClassOutput: """
open class OpenStudent: BaseModel {
var age: Int = 0

enum CodingKeys: CodingKey {
case age
}

public required init(from decoder: Decoder) throws {
try super.init(from: decoder)

let container = try decoder.container(keyedBy: CodingKeys.self)
self.age = try container.decodeIfPresent(Int.self, forKey: .age) ?? self.age
}

public override func encode(to encoder: Encoder) throws {
try super.encode(to: encoder)

var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(age, forKey: .age)
}

public required init() {
super.init()
}
}
"""
)
}

/// Verifies that internal/default class does not add explicit `public` modifiers.
func testInternalClassDoesNotAddPublicAccessModifiers() {
assertAccessControlMacroExpansion(
classDeclaration: """
class InternalStudent: BaseModel {
var age: Int = 0
}
""",
expectedClassOutput: """
class InternalStudent: BaseModel {
var age: Int = 0

enum CodingKeys: CodingKey {
case age
}

required init(from decoder: Decoder) throws {
try super.init(from: decoder)

let container = try decoder.container(keyedBy: CodingKeys.self)
self.age = try container.decodeIfPresent(Int.self, forKey: .age) ?? self.age
}

override func encode(to encoder: Encoder) throws {
try super.encode(to: encoder)

var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(age, forKey: .age)
}

required init() {
super.init()
}
}
"""
)
}

/// Verifies that existing `required init()` prevents duplicate generation.
func testClassWithExistingRequiredInitSkipsGeneratedInit() {
assertAccessControlMacroExpansion(
classDeclaration: """
public class StudentWithInit: BaseModel {
var age: Int = 0

required init() {
super.init()
}
}
""",
expectedClassOutput: """
public class StudentWithInit: BaseModel {
var age: Int = 0

required init() {
super.init()
}

enum CodingKeys: CodingKey {
case age
}

public required init(from decoder: Decoder) throws {
try super.init(from: decoder)

let container = try decoder.container(keyedBy: CodingKeys.self)
self.age = try container.decodeIfPresent(Int.self, forKey: .age) ?? self.age
}

public override func encode(to encoder: Encoder) throws {
try super.encode(to: encoder)

var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(age, forKey: .age)
}
}
"""
)
}
}