|
| 1 | +--- |
| 2 | +layout: post |
| 3 | +title: "ARKit+Swift 版本的机器学习算法 k-NN" |
| 4 | +author: iosdevlog |
| 5 | +date: 2019-04-08 21:18:56 +0800 |
| 6 | +description: "" |
| 7 | +category: 机器学习 |
| 8 | +tags: [knn] |
| 9 | +--- |
| 10 | + |
| 11 | + |
| 12 | + |
| 13 | + |
| 14 | + |
| 15 | + |
| 16 | + |
| 17 | +# 维基介绍 |
| 18 | + |
| 19 | +在[模式识别](https://zh.wikipedia.org/wiki/%E6%A8%A1%E5%BC%8F%E8%AF%86%E5%88%AB "模式识别")领域中,**最近邻居法**(**KNN**算法,又译**K-近邻算法**)是一种用于[分类](https://zh.wikipedia.org/wiki/%E5%88%86%E7%B1%BB%E9%97%AE%E9%A2%98 "分类问题")和[回归](https://zh.wikipedia.org/wiki/%E8%BF%B4%E6%AD%B8%E5%88%86%E6%9E%90 "回归分析")的[非参数统计](https://zh.wikipedia.org/wiki/%E7%84%A1%E6%AF%8D%E6%95%B8%E7%B5%B1%E8%A8%88 "非参数统计")方法<sup>[[1]](https://zh.wikipedia.org/wiki/%E6%9C%80%E8%BF%91%E9%84%B0%E5%B1%85%E6%B3%95#cite_note-1)</sup>。在这两种情况下,输入包含[特征空间(Feature Space)](https://zh.wikipedia.org/w/index.php?title=%E7%89%B9%E5%BE%B5%E7%A9%BA%E9%96%93(%E6%A9%9F%E5%99%A8%E5%AD%B8%E7%BF%92)&action=edit&redlink=1)中的***k***个最接近的训练样本。 |
| 20 | + |
| 21 | +* 在*k-NN分类*中,输出是一个分类族群。一个对象的分类是由其邻居的“多数表决”确定的,*k*个最近邻居(*k*为正[整数](https://zh.wikipedia.org/wiki/%E6%95%B4%E6%95%B0 "整数"),通常较小)中最常见的分类决定了赋予该对象的类别。若*k* = 1,则该对象的类别直接由最近的一个节点赋予。 |
| 22 | + |
| 23 | +* 在*k-NN回归*中,输出是该对象的属性值。该值是其*k*个最近邻居的值的平均值。 |
| 24 | + |
| 25 | + |
| 26 | +最近邻居法采用向量空间模型来分类,概念为相同类别的案例,彼此的相似度高,而可以借由计算与已知类别案例之相似度,来评估未知类别案例可能的分类。 |
| 27 | + |
| 28 | +K-NN是一种[基于实例的学习](https://zh.wikipedia.org/w/index.php?title=%E5%9F%BA%E4%BA%8E%E5%AE%9E%E4%BE%8B%E7%9A%84%E5%AD%A6%E4%B9%A0&action=edit&redlink=1),或者是局部近似和将所有计算推迟到分类之后的[惰性学习](https://zh.wikipedia.org/w/index.php?title=%E6%83%B0%E6%80%A7%E5%AD%A6%E4%B9%A0&action=edit&redlink=1)。k-近邻算法是所有的[机器学习](https://zh.wikipedia.org/wiki/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0 "机器学习")算法中最简单的之一。 |
| 29 | + |
| 30 | +无论是分类还是回归,衡量邻居的权重都非常有用,使较近邻居的权重比较远邻居的权重大。例如,一种常见的加权方案是给每个邻居权重赋值为1/ d,其中d是到邻居的距离。<sup>[[注 1]](https://zh.wikipedia.org/wiki/%E6%9C%80%E8%BF%91%E9%84%B0%E5%B1%85%E6%B3%95#cite_note-2)</sup> |
| 31 | + |
| 32 | +邻居都取自一组已经正确分类(在回归的情况下,指属性值正确)的对象。虽然没要求明确的训练步骤,但这也可以当作是此算法的一个训练样本集。 |
| 33 | + |
| 34 | +k-近邻算法的缺点是对数据的局部结构非常敏感。本算法与[K-平均算法](https://zh.wikipedia.org/wiki/K-%E5%B9%B3%E5%9D%87%E7%AE%97%E6%B3%95 "K-平均算法")(另一流行的机器学习技术)没有任何关系,请勿与之混淆。 |
| 35 | + |
| 36 | +# ARKit + Swift + k-NN 实现 |
| 37 | + |
| 38 | +创建 KNN 类(结构体 `struct` 也行,我是为了 与 `sklearn` 尽量一致)。 |
| 39 | + |
| 40 | +``` |
| 41 | +/// KNN |
| 42 | +public struct KNN<Feature, Label: Hashable> { |
| 43 | +} |
| 44 | +``` |
| 45 | + |
| 46 | +属性 |
| 47 | + |
| 48 | +```swift |
| 49 | + /// Number of neighbors to use by default for :meth:`kneighbors` queries |
| 50 | + private var k: Int |
| 51 | + /// Data set |
| 52 | + private var X = [Feature]() |
| 53 | + /// Target values |
| 54 | + private var y = [Label]() |
| 55 | + |
| 56 | + |
| 57 | + /// distance |
| 58 | + private let distanceMetric: (_ x1: Feature, _ x2: Feature) -> Double |
| 59 | + /// draw radius for debug |
| 60 | + public var debugRadiusCallback: (([Double]) -> ())? = nil |
| 61 | +``` |
| 62 | +数据: |
| 63 | + |
| 64 | +* `k`: 指定取 k 个最接近的训练样本 |
| 65 | +* `X`: 样本特征 (数组)一般要传数组的数组 |
| 66 | +* `y`: 样本标签 (数组) |
| 67 | + |
| 68 | +辅助: |
| 69 | + |
| 70 | +* `distanceMetric`: 用来计算距离的函数 |
| 71 | +* `debugRadiusCallback`: 调度时候用,主要是画出最近的 k 个样本范围 |
| 72 | + |
| 73 | +# 方法 |
| 74 | + |
| 75 | +```swift |
| 76 | + /// constructorLabels for xTest |
| 77 | + /// |
| 78 | + /// - Parameters: |
| 79 | + /// - k: k |
| 80 | + /// - distanceMetric: distance |
| 81 | + public init (k: Int, distanceMetric: @escaping (_ x1: Feature, _ x2: Feature) -> Double) |
| 82 | + |
| 83 | + /// train |
| 84 | + /// |
| 85 | + /// - Parameters: |
| 86 | + /// - X: Training set |
| 87 | + /// - y: Target values |
| 88 | + public mutating func fit(X: [Feature], y: [Label]) |
| 89 | + |
| 90 | + |
| 91 | + /// Labels for xTest |
| 92 | + /// |
| 93 | + /// - Parameter XTest: Test set |
| 94 | + /// - Returns: Target values |
| 95 | + public func predict(XTest: [Feature]) -> [Label] |
| 96 | +``` |
| 97 | + |
| 98 | +* `init()`: 构造函数 需要预先决定 `k` 和距离计算方法 |
| 99 | +* `fit()`: 拟合目标函数,kNN 不需要拟合,只要记下数据即可 |
| 100 | +* `predict()`: 预测给定的特征,返回对应的标签 |
| 101 | + |
| 102 | +## 距离计算 |
| 103 | + |
| 104 | +```swift |
| 105 | +public struct Distance { |
| 106 | + |
| 107 | + /// Helper function to calculate euclidian distance |
| 108 | + /// |
| 109 | + /// - Parameters: |
| 110 | + /// - x0: source coordinate |
| 111 | + /// - x1: target coordinate |
| 112 | + /// - Returns: euclidian distance |
| 113 | + public static func euclideanDistance(_ x0: [Double], _ x1: [Double]) -> Double |
| 114 | + |
| 115 | + // Convenience |
| 116 | + public static func euclideanDistance() -> (([Double], [Double]) -> Double) { |
| 117 | + return { self.euclideanDistance($0, $1) } |
| 118 | + } |
| 119 | +``` |
| 120 | + |
| 121 | +这里使用 欧式距离(Euclidean Distance) |
| 122 | + |
| 123 | +KNN 使用: |
| 124 | + |
| 125 | +```swift |
| 126 | + func testKNN() { |
| 127 | + let kNN = KNN<Double, Int>(k: 3) |
| 128 | + let X = [[1.0], [2], [3], [4]] |
| 129 | + let y = [0, 0, 1, 1] |
| 130 | + kNN.fit(X, y) |
| 131 | + |
| 132 | + let label = kNN.predict([[1.2], [3]]) |
| 133 | + |
| 134 | + XCTAssertEqual([0, 1], label) |
| 135 | + } |
| 136 | +``` |
| 137 | + |
| 138 | +## 显示 2 维 |
| 139 | + |
| 140 | +```swift |
| 141 | +enum MLStep: Int { |
| 142 | + case train |
| 143 | + case predict |
| 144 | +} |
| 145 | + |
| 146 | +enum GeometryType: Int { |
| 147 | + case box |
| 148 | + case pyramid |
| 149 | + case sphere |
| 150 | + case cylinder |
| 151 | + case cone |
| 152 | + case tube |
| 153 | + case torus |
| 154 | +... |
| 155 | +} |
| 156 | +``` |
| 157 | + |
| 158 | +* `MLStep`: 分成 训练 和 预测 ,训练一次,可以一直预测。 |
| 159 | +* `GeometryType`: 定义几种形状,这里定义给 `ARKIT` 使用的 |
| 160 | + |
| 161 | +## KNNViewController |
| 162 | + |
| 163 | +```swift |
| 164 | +class KNNViewController: UIViewController { |
| 165 | + |
| 166 | + let radius: CGFloat = 5 |
| 167 | + |
| 168 | + public var X: [[Double]] = [] |
| 169 | + public var y: [GeometryType] = [] |
| 170 | + public var XTest: [[Double]] = [] |
| 171 | + public var yTest: [GeometryType] = [] |
| 172 | + public var radiuses: [Double] = [] { |
| 173 | + didSet { |
| 174 | + for (center, r) in zip(XTest, radiuses) { |
| 175 | + drawCircle(center: CGPoint(x: center[0], y: center[1]), radius: CGFloat(r)) |
| 176 | + } |
| 177 | + } |
| 178 | + } |
| 179 | + public var predictLayers: [CALayer] = [] |
| 180 | + |
| 181 | + var model = KNN<[Double], GeometryType>(k: 1, distanceMetric: Distance.euclideanDistance()) |
| 182 | + |
| 183 | + @IBOutlet weak var kNNPickerView: UIPickerView! |
| 184 | + @IBOutlet weak var panelView: UIView! |
| 185 | + @IBOutlet weak var trainBarButtonItem: UIBarButtonItem! |
| 186 | + |
| 187 | + var mlStep = MLStep.train { |
| 188 | + didSet { |
| 189 | + switch mlStep { |
| 190 | + case .train: |
| 191 | + trainBarButtonItem.title = "train" |
| 192 | + default: |
| 193 | + trainBarButtonItem.title = "predict" |
| 194 | + } |
| 195 | + } |
| 196 | + } |
| 197 | +... |
| 198 | +} |
| 199 | +``` |
| 200 | + |
| 201 | +添加 `Layer` 到 `panelView` 上实现类别,2D 只分了三个类别,分别用 方形(红),三角形(蓝),花形(绿)表示。 |
| 202 | + |
| 203 | +使用 `alpha` 表示预测类别,以预测样本为中心画一个圈,圈内为最近的 `k` 个样本。 |
| 204 | + |
| 205 | +```swift |
| 206 | + func drawCircle(center: CGPoint, radius: CGFloat, alpha: CGFloat = 0.1) { |
| 207 | + let r = self.radius + radius |
| 208 | + let kNNCircleLayer = CAShapeLayer() |
| 209 | + kNNCircleLayer.path = UIBezierPath(roundedRect: CGRect(x: center.x - r, y: center.y - r, width: r * 2, height: r * 2), cornerRadius: r).cgPath |
| 210 | + kNNCircleLayer.fillColor = UIColor(red: 0.1, green: 0.1, blue: 0.1, alpha: alpha).cgColor |
| 211 | + kNNCircleLayer.borderColor = UIColor(red: 0.8, green: 0.8, blue: 0.8, alpha: 1).cgColor |
| 212 | + kNNCircleLayer.borderWidth = 1 |
| 213 | + panelView.layer.addSublayer(kNNCircleLayer) |
| 214 | + } |
| 215 | +``` |
| 216 | + |
| 217 | +圆内为 `k` 个样本。 |
| 218 | + |
| 219 | + |
| 220 | + |
| 221 | +# ARKit 实现 |
| 222 | + |
| 223 | +能 3D 展示多好,别急,下面就是用 `ARKit` 实现的 3D 版本。 |
| 224 | + |
| 225 | +```swift |
| 226 | +class ARKNNViewController: UIViewController |
| 227 | +``` |
| 228 | + |
| 229 | +基本实现和 `ARKNNViewController` 类似。 |
| 230 | + |
| 231 | +```swift |
| 232 | + func drawSphere(center: SCNVector3, radius: Float) { |
| 233 | + let geometry = SCNSphere(radius: CGFloat(radius) + self.radius) |
| 234 | + |
| 235 | + geometry.firstMaterial?.diffuse.contents = UIColor(red: 0.1, green: 0.1, blue: 0.8, alpha: 0.7) |
| 236 | + |
| 237 | + let node = SCNNode() |
| 238 | + node.geometry = geometry |
| 239 | + node.position = center |
| 240 | + sceneView.scene.rootNode.addChildNode(node) |
| 241 | + } |
| 242 | +``` |
| 243 | + |
| 244 | +这是画最外面的范围球体,球体内为 `k` 个样本。 |
| 245 | + |
| 246 | + |
| 247 | + |
| 248 | +# 视频 |
| 249 | + |
| 250 | +b站视频:[https://www.bilibili.com/video/av48647681/](https://www.bilibili.com/video/av48647681/) |
0 commit comments