-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Expand file tree
/
Copy pathConnection+Aggregation.swift
More file actions
155 lines (139 loc) · 6.13 KB
/
Connection+Aggregation.swift
File metadata and controls
155 lines (139 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import Foundation
#if StandaloneSQLite
import sqlite3
#elseif SQLCipher
import SQLCipher
#elseif SwiftToolchainCSQLite
import SwiftToolchainCSQLite
#else
import SQLite3 // SystemSQLite
#endif
extension Connection {
    /// Signature of the Objective-C block stored as each aggregate's SQLite
    /// user-data pointer (`pApp`). The first argument is `1` for a step call
    /// and `0` for the final call.
    private typealias Aggregate = @convention(block) (Int, Context, Int32, Argv) -> Void

    /// Creates or redefines a custom SQL aggregate.
    ///
    /// - Parameters:
    ///
    ///   - functionName: The name of the aggregate to create or redefine.
    ///
    ///   - argumentCount: The number of arguments that the aggregate takes. If
    ///     `nil`, the aggregate may take any number of arguments.
    ///
    ///     Default: `nil`
    ///
    ///   - deterministic: Whether or not the aggregate is deterministic (_i.e._
    ///     the aggregate always returns the same result for a given input).
    ///
    ///     Default: `false`
    ///
    ///   - step: A block of code to run for each row of an aggregation group.
    ///     The block is called with an array of raw SQL values mapped to the
    ///     aggregate’s parameters, and an UnsafeMutablePointer to a state
    ///     variable.
    ///
    ///   - final: A block of code to run once after all rows of an aggregation
    ///     group have been processed. It is also run for a group that
    ///     contained no rows, in which case `step` was never called and the
    ///     state is freshly produced by `state`. The block is called with an
    ///     UnsafeMutablePointer to the state variable and should return a raw
    ///     SQL value (or nil). This method never frees the state itself, so
    ///     `final` is the place to release it.
    ///
    ///   - state: A block of code to run to produce a fresh state variable for
    ///     each aggregation group. The block should return an
    ///     UnsafeMutablePointer to the fresh state variable.
    public func createAggregation<T>(
        _ functionName: String,
        argumentCount: UInt? = nil,
        deterministic: Bool = false,
        step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> Void,
        final: @escaping (UnsafeMutablePointer<T>) -> Binding?,
        state: @escaping () -> UnsafeMutablePointer<T>) {

        let argc = argumentCount.map { Int($0) } ?? -1
        let box: Aggregate = { (stepFlag: Int, context: Context, valueCount: Int32, argv: Argv) in
            // Per-group buffer owned by SQLite, allocated and zero-filled on first
            // use. It holds a single pointer to the state produced by `state()`.
            let nBytes = Int32(MemoryLayout<UnsafeMutableRawPointer?>.size)
            guard let aggregateContext = sqlite3_aggregate_context(context, nBytes) else {
                fatalError("Could not get aggregate context")
            }
            let slot = aggregateContext.assumingMemoryBound(to: UnsafeMutableRawPointer?.self)

            // A zero-filled (nil) slot means this group has no state yet. That is
            // also true when xFinal runs for a group with no rows (e.g. an empty
            // table), so the state must be created lazily on BOTH paths instead of
            // dereferencing a null pointer in `final`.
            let statePointer: UnsafeMutablePointer<T>
            if let existing = slot.pointee {
                statePointer = existing.assumingMemoryBound(to: T.self)
            } else {
                statePointer = state()
                slot.pointee = UnsafeMutableRawPointer(statePointer)
            }

            if stepFlag > 0 {
                step(argv.getBindings(argc: valueCount), statePointer)
            } else {
                context.set(result: final(statePointer))
            }
        }

        // C-compatible trampolines: recover `box` from the user-data pointer.
        func xStep(context: Context, argc: Int32, value: Argv) {
            unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)(1, context, argc, value)
        }

        func xFinal(context: Context) {
            unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)(0, context, 0, nil)
        }

        let flags = SQLITE_UTF8 | (deterministic ? SQLITE_DETERMINISTIC : 0)
        let resultCode = sqlite3_create_function_v2(
            handle,
            functionName,
            Int32(argc),
            flags,
            /* pApp */ unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
            /* xFunc */ nil, xStep, xFinal, /* xDestroy */ nil
        )
        if let result = Result(errorCode: resultCode, connection: self) {
            fatalError("Error creating function: \(result)")
        }
        // `pApp` is an unretained pointer to `box` and no xDestroy is supplied,
        // so `box` must be kept alive elsewhere; `register` is assumed to retain
        // it for the connection's lifetime — confirm.
        register(functionName, argc: argc, value: box)
    }

    /// Creates or redefines a custom SQL aggregate whose running state is a
    /// class instance.
    ///
    /// - Parameters:
    ///   - aggregate: The name of the aggregate to create or redefine.
    ///   - argumentCount: The number of arguments that the aggregate takes. If
    ///     `nil`, the aggregate may take any number of arguments. Default: `nil`
    ///   - deterministic: Whether or not the aggregate is deterministic.
    ///     Default: `false`
    ///   - initialValue: The starting state of every aggregation group. The
    ///     SAME instance seeds every group, so `reduce` should return a new
    ///     object rather than mutating its input in place. Otherwise, later
    ///     groups start from an already-modified value.
    ///   - reduce: Combines the current state with one row's raw SQL values and
    ///     returns the next state.
    ///   - result: Maps a group's final state to a raw SQL value (or nil).
    public func createAggregation<T: AnyObject>(
        _ aggregate: String,
        argumentCount: UInt? = nil,
        deterministic: Bool = false,
        initialValue: T,
        reduce: @escaping (T, [Binding?]) -> T,
        result: @escaping (T) -> Binding?
    ) {
        // The state cell holds an opaque, +1 retained reference to the current value.
        let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Void = { (bindings, ptr) in
            let current = Unmanaged<T>.fromOpaque(ptr.pointee).takeRetainedValue()
            let next = reduce(current, bindings)
            ptr.pointee = Unmanaged.passRetained(next).toOpaque()
        }
        // Balances the last retain, then frees the cell allocated by `state`.
        let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { ptr in
            let obj = Unmanaged<T>.fromOpaque(ptr.pointee).takeRetainedValue()
            let value = result(obj)
            ptr.deallocate()
            return value
        }
        let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
            let pointer = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
            pointer.initialize(to: Unmanaged.passRetained(initialValue).toOpaque())
            return pointer
        }
        // Forward every option; previously argumentCount/deterministic were dropped.
        createAggregation(aggregate,
                          argumentCount: argumentCount,
                          deterministic: deterministic,
                          step: step,
                          final: final,
                          state: state)
    }

    /// Creates or redefines a custom SQL aggregate whose running state is a
    /// value of type `T`.
    ///
    /// - Parameters:
    ///   - aggregate: The name of the aggregate to create or redefine.
    ///   - argumentCount: The number of arguments that the aggregate takes. If
    ///     `nil`, the aggregate may take any number of arguments. Default: `nil`
    ///   - deterministic: Whether or not the aggregate is deterministic.
    ///     Default: `false`
    ///   - initialValue: The starting state of every aggregation group.
    ///   - reduce: Combines the current state with one row's raw SQL values and
    ///     returns the next state.
    ///   - result: Maps a group's final state to a raw SQL value (or nil).
    public func createAggregation<T>(
        _ aggregate: String,
        argumentCount: UInt? = nil,
        deterministic: Bool = false,
        initialValue: T,
        reduce: @escaping (T, [Binding?]) -> T,
        result: @escaping (T) -> Binding?
    ) {
        let step: ([Binding?], UnsafeMutablePointer<T>) -> Void = { (bindings, pointer) in
            pointer.pointee = reduce(pointer.pointee, bindings)
        }
        let final: (UnsafeMutablePointer<T>) -> Binding? = { pointer in
            // move() deinitializes the cell, so non-trivial T (String, Array, ...)
            // is released before the raw memory is freed; deallocate() alone leaks.
            let value = result(pointer.move())
            pointer.deallocate()
            return value
        }
        let state: () -> UnsafeMutablePointer<T> = {
            let pointer = UnsafeMutablePointer<T>.allocate(capacity: 1)
            pointer.initialize(to: initialValue)
            return pointer
        }
        // Forward every option; previously argumentCount/deterministic were dropped.
        createAggregation(aggregate,
                          argumentCount: argumentCount,
                          deterministic: deterministic,
                          step: step,
                          final: final,
                          state: state)
    }
}