// PhotoClassifier.swift
import CoreImage
import CoreLocation
import Photos
import UIKit
import Vision
  3. class PhotoClassifier {
  4. struct ClassifiedPhotos {
  5. var screenshots: [PHAsset] = []
  6. var locations: [String: [PHAsset]] = [:] // 按地点分组
  7. var people: [String: [PHAsset]] = [:] // 按人物分组
  8. var similarPhotos: [[PHAsset]] = [] // 存储相似照片组
  9. }
  10. func classifyPhotos(
  11. assets: PHFetchResult<PHAsset>,
  12. progressHandler: @escaping (String, Float) -> Void,
  13. completion: @escaping (ClassifiedPhotos) -> Void
  14. ) {
  15. // 在后台队列处理
  16. DispatchQueue.global(qos: .userInitiated).async {
  17. var result = ClassifiedPhotos()
  18. let group = DispatchGroup()
  19. // 开始处理
  20. DispatchQueue.main.async {
  21. progressHandler("正在加载照片...", 0.0)
  22. }
  23. // 1. 检测截图 (占总进度的 20%)
  24. group.enter()
  25. self.fetchScreenshots(from: assets) { screenshots in
  26. result.screenshots = screenshots
  27. DispatchQueue.main.async {
  28. progressHandler("正在检测截图...", 0.2)
  29. }
  30. group.leave()
  31. }
  32. // 2. 检测相似照片 (占总进度的 80%)
  33. group.enter()
  34. self.detectSimilarPhotos(
  35. assets: assets,
  36. progressHandler: { stage, progress in
  37. // 将相似照片检测的进度映射到 20%-100% 的范围
  38. let mappedProgress = 0.2 + (progress * 0.8)
  39. DispatchQueue.main.async {
  40. progressHandler(stage, mappedProgress)
  41. }
  42. }
  43. ) { similarPhotos in
  44. result.similarPhotos = similarPhotos
  45. group.leave()
  46. }
  47. // 等待所有处理完成
  48. group.notify(queue: .main) {
  49. progressHandler("分类完成", 1.0)
  50. completion(result)
  51. }
  52. }
  53. }
  54. private func detectSimilarPhotos(
  55. assets: PHFetchResult<PHAsset>,
  56. progressHandler: @escaping (String, Float) -> Void,
  57. completion: @escaping ([[PHAsset]]) -> Void
  58. ) {
  59. var similarGroups: [[PHAsset]] = []
  60. let group = DispatchGroup()
  61. var imageFeatures: [(asset: PHAsset, feature: VNFeaturePrintObservation)] = []
  62. // 创建处理队列
  63. let processingQueue = DispatchQueue(label: "com.app.similarPhotos", qos: .userInitiated)
  64. let semaphore = DispatchSemaphore(value: 5)
  65. // 1. 提取所有图片的特征
  66. let totalAssets = assets.count
  67. var processedAssets = 0
  68. progressHandler("正在加载照片...", 0.0)
  69. for i in 0..<assets.count {
  70. let asset = assets[i]
  71. group.enter()
  72. semaphore.wait()
  73. let options = PHImageRequestOptions()
  74. options.deliveryMode = .highQualityFormat
  75. options.isSynchronous = false
  76. options.resizeMode = .exact
  77. PHImageManager.default().requestImage(
  78. for: asset,
  79. targetSize: CGSize(width: 448, height: 448),
  80. contentMode: .aspectFit,
  81. options: options
  82. ) { image, _ in
  83. defer {
  84. semaphore.signal()
  85. }
  86. guard let image = image,
  87. let cgImage = image.cgImage else {
  88. group.leave()
  89. return
  90. }
  91. processingQueue.async {
  92. do {
  93. let requestHandler = VNImageRequestHandler(cgImage: cgImage, options: [:])
  94. let request = VNGenerateImageFeaturePrintRequest()
  95. try requestHandler.perform([request])
  96. if let result = request.results?.first as? VNFeaturePrintObservation {
  97. imageFeatures.append((asset, result))
  98. // 更新特征提取进度
  99. processedAssets += 1
  100. let progress = Float(processedAssets) / Float(totalAssets)
  101. progressHandler("正在提取特征...", progress * 0.6)
  102. }
  103. } catch {
  104. print("特征提取失败: \(error)")
  105. }
  106. group.leave()
  107. }
  108. }
  109. }
  110. // 2. 比较特征相似度并分组
  111. group.notify(queue: processingQueue) {
  112. progressHandler("正在比较相似度...", 0.6)
  113. // 近似度
  114. let similarityThreshold: Float = 0.7
  115. var processedComparisons = 0
  116. let totalComparisons = (imageFeatures.count * (imageFeatures.count - 1)) / 2
  117. var processedIndices = Set<Int>()
  118. for i in 0..<imageFeatures.count {
  119. if processedIndices.contains(i) { continue }
  120. var similarGroup: [PHAsset] = [imageFeatures[i].asset]
  121. processedIndices.insert(i)
  122. for j in (i + 1)..<imageFeatures.count {
  123. if processedIndices.contains(j) { continue }
  124. do {
  125. var distance: Float = 0
  126. try imageFeatures[i].feature.computeDistance(&distance, to: imageFeatures[j].feature)
  127. let similarity = 1 - distance
  128. if similarity >= similarityThreshold {
  129. similarGroup.append(imageFeatures[j].asset)
  130. processedIndices.insert(j)
  131. }
  132. // 更新比较进度
  133. processedComparisons += 1
  134. let compareProgress = Float(processedComparisons) / Float(totalComparisons)
  135. progressHandler("正在比较相似度...", 0.6 + compareProgress * 0.4)
  136. } catch {
  137. print("相似度计算失败: \(error)")
  138. }
  139. }
  140. if similarGroup.count > 1 {
  141. similarGroups.append(similarGroup)
  142. }
  143. }
  144. // 按照照片数量降序排序
  145. similarGroups.sort { $0.count > $1.count }
  146. DispatchQueue.main.async {
  147. completion(similarGroups)
  148. }
  149. }
  150. }
  151. // 按地点分类
  152. private func classifyByLocation(assets: PHFetchResult<PHAsset>,
  153. completion: @escaping ([String: [PHAsset]]) -> Void) {
  154. var locationGroups: [String: [PHAsset]] = [:]
  155. let group = DispatchGroup()
  156. let geocodeQueue = DispatchQueue(label: "com.app.geocoding")
  157. let semaphore = DispatchSemaphore(value: 10) // 限制并发请求数
  158. assets.enumerateObjects { asset, _, _ in
  159. if let location = asset.location {
  160. group.enter()
  161. semaphore.wait()
  162. geocodeQueue.async {
  163. let geocoder = CLGeocoder()
  164. geocoder.reverseGeocodeLocation(location) { placemarks, error in
  165. defer {
  166. semaphore.signal()
  167. group.leave()
  168. }
  169. if let placemark = placemarks?.first {
  170. let locationName = self.formatLocationName(placemark)
  171. DispatchQueue.main.async {
  172. if locationGroups[locationName] == nil {
  173. locationGroups[locationName] = []
  174. }
  175. locationGroups[locationName]?.append(asset)
  176. }
  177. }
  178. }
  179. }
  180. }
  181. }
  182. // 等待所有地理编码完成后回调
  183. group.notify(queue: .main) {
  184. completion(locationGroups)
  185. }
  186. }
  187. // 格式化地点名称(只返回城市名)
  188. private func formatLocationName(_ placemark: CLPlacemark) -> String {
  189. if let city = placemark.locality {
  190. return city
  191. }
  192. return "其他"
  193. }
  194. // 按人物分类
  195. private func classifyByPeople(assets: PHFetchResult<PHAsset>,
  196. completion: @escaping ([String: [PHAsset]]) -> Void) {
  197. var peopleGroups: [String: [PHAsset]] = [:]
  198. let group = DispatchGroup()
  199. // 创建一个数组来存储检测到人脸的照片
  200. var facesArray: [PHAsset] = []
  201. // 遍历所有照片
  202. assets.enumerateObjects { asset, _, _ in
  203. group.enter()
  204. // 获取照片的缩略图进行人脸检测
  205. let options = PHImageRequestOptions()
  206. options.isSynchronous = false
  207. options.deliveryMode = .fastFormat
  208. PHImageManager.default().requestImage(
  209. for: asset,
  210. targetSize: CGSize(width: 500, height: 500), // 使用较小的尺寸提高性能
  211. contentMode: .aspectFit,
  212. options: options
  213. ) { image, _ in
  214. guard let image = image else {
  215. group.leave()
  216. return
  217. }
  218. // 使用 Vision 框架检测人脸
  219. guard let ciImage = CIImage(image: image) else {
  220. group.leave()
  221. return
  222. }
  223. let request = VNDetectFaceRectanglesRequest()
  224. let handler = VNImageRequestHandler(ciImage: ciImage)
  225. do {
  226. try handler.perform([request])
  227. if let results = request.results, !results.isEmpty {
  228. // 检测到人脸,添加到数组
  229. DispatchQueue.main.async {
  230. facesArray.append(asset)
  231. }
  232. }
  233. } catch {
  234. print("人脸检测失败: \(error)")
  235. }
  236. group.leave()
  237. }
  238. }
  239. // 等待所有检测完成后更新结果
  240. group.notify(queue: .main) {
  241. if !facesArray.isEmpty {
  242. peopleGroups["包含人脸的照片"] = facesArray
  243. }
  244. completion(peopleGroups)
  245. }
  246. }
  247. // 识别截图
  248. private func fetchScreenshots(from assets: PHFetchResult<PHAsset>,
  249. completion: @escaping ([PHAsset]) -> Void) {
  250. var screenshots: [PHAsset] = []
  251. // 获取系统的截图智能相册
  252. let screenshotAlbums = PHAssetCollection.fetchAssetCollections(
  253. with: .smartAlbum,
  254. subtype: .smartAlbumScreenshots,
  255. options: nil
  256. )
  257. // 从截图相册中获取所有截图
  258. screenshotAlbums.enumerateObjects { collection, _, _ in
  259. let fetchOptions = PHFetchOptions()
  260. let screenshotAssets = PHAsset.fetchAssets(in: collection, options: fetchOptions)
  261. screenshotAssets.enumerateObjects { asset, _, _ in
  262. screenshots.append(asset)
  263. }
  264. }
  265. completion(screenshots)
  266. }
  267. }