// PhotoClassifier.swift
import CoreLocation
import Photos
import UIKit
import Vision
  3. class PhotoClassifier {
  4. struct ClassifiedPhotos {
  5. var screenshots: [PHAsset] = []
  6. var locations: [String: [PHAsset]] = [:] // 按地点分组
  7. var people: [String: [PHAsset]] = [:] // 按人物分组
  8. var similarPhotos: [[PHAsset]] = [] // 存储相似照片组
  9. }
  10. func classifyPhotos(
  11. assets: PHFetchResult<PHAsset>,
  12. progressHandler: @escaping (String, Float) -> Void,
  13. completion: @escaping (ClassifiedPhotos) -> Void
  14. ) {
  15. // 在后台队列处理
  16. DispatchQueue.global(qos: .userInitiated).async {
  17. var result = ClassifiedPhotos()
  18. let group = DispatchGroup()
  19. // 开始处理
  20. DispatchQueue.main.async {
  21. progressHandler("正在加载照片...", 0.0)
  22. }
  23. // 1. 检测截图 (占总进度的 20%)
  24. group.enter()
  25. self.fetchScreenshots(from: assets) { screenshots in
  26. result.screenshots = screenshots
  27. DispatchQueue.main.async {
  28. progressHandler("正在检测截图...", 0.2)
  29. }
  30. group.leave()
  31. }
  32. // 2. 检测相似照片 (占总进度的 80%)
  33. group.enter()
  34. self.detectSimilarPhotos(
  35. assets: assets,
  36. progressHandler: { stage, progress in
  37. // 将相似照片检测的进度映射到 20%-100% 的范围
  38. let mappedProgress = 0.2 + (progress * 0.8)
  39. DispatchQueue.main.async {
  40. progressHandler(stage, mappedProgress)
  41. }
  42. }
  43. ) { similarPhotos in
  44. result.similarPhotos = similarPhotos
  45. group.leave()
  46. }
  47. // 等待所有处理完成
  48. group.notify(queue: .main) {
  49. progressHandler("分类完成", 1.0)
  50. completion(result)
  51. }
  52. }
  53. }
  54. private func detectSimilarPhotos(
  55. assets: PHFetchResult<PHAsset>,
  56. progressHandler: @escaping (String, Float) -> Void,
  57. completion: @escaping ([[PHAsset]]) -> Void
  58. ) {
  59. var similarGroups: [[PHAsset]] = []
  60. let group = DispatchGroup()
  61. if #available(iOS 13.0, *) {
  62. var imageFeatures: [(asset: PHAsset, feature: VNFeaturePrintObservation)] = []
  63. // 创建处理队列
  64. let processingQueue = DispatchQueue(label: "com.app.similarPhotos", qos: .userInitiated)
  65. let semaphore = DispatchSemaphore(value: 5)
  66. // 1. 提取所有图片的特征
  67. let totalAssets = assets.count
  68. var processedAssets = 0
  69. progressHandler("正在加载照片...", 0.0)
  70. for i in 0..<assets.count {
  71. let asset = assets[i]
  72. group.enter()
  73. semaphore.wait()
  74. let options = PHImageRequestOptions()
  75. options.deliveryMode = .highQualityFormat
  76. options.isSynchronous = false
  77. options.resizeMode = .exact
  78. PHImageManager.default().requestImage(
  79. for: asset,
  80. targetSize: CGSize(width: 448, height: 448),
  81. contentMode: .aspectFit,
  82. options: options
  83. ) { image, _ in
  84. defer {
  85. semaphore.signal()
  86. }
  87. guard let image = image,
  88. let cgImage = image.cgImage else {
  89. group.leave()
  90. return
  91. }
  92. processingQueue.async {
  93. do {
  94. let requestHandler = VNImageRequestHandler(cgImage: cgImage, options: [:])
  95. if #available(iOS 13.0, *) {
  96. let request = VNGenerateImageFeaturePrintRequest()
  97. try requestHandler.perform([request])
  98. if let result = request.results?.first as? VNFeaturePrintObservation {
  99. imageFeatures.append((asset, result))
  100. // 更新特征提取进度
  101. processedAssets += 1
  102. let progress = Float(processedAssets) / Float(totalAssets)
  103. progressHandler("正在提取特征...", progress * 0.6)
  104. }
  105. } else {
  106. // Fallback on earlier versions
  107. }
  108. } catch {
  109. print("特征提取失败: \(error)")
  110. }
  111. group.leave()
  112. }
  113. }
  114. }
  115. // 2. 比较特征相似度并分组
  116. group.notify(queue: processingQueue) {
  117. progressHandler("正在比较相似度...", 0.6)
  118. // 近似度
  119. let similarityThreshold: Float = 0.7
  120. var processedComparisons = 0
  121. let totalComparisons = (imageFeatures.count * (imageFeatures.count - 1)) / 2
  122. var processedIndices = Set<Int>()
  123. for i in 0..<imageFeatures.count {
  124. if processedIndices.contains(i) { continue }
  125. var similarGroup: [PHAsset] = [imageFeatures[i].asset]
  126. processedIndices.insert(i)
  127. for j in (i + 1)..<imageFeatures.count {
  128. if processedIndices.contains(j) { continue }
  129. do {
  130. var distance: Float = 0
  131. try imageFeatures[i].feature.computeDistance(&distance, to: imageFeatures[j].feature)
  132. let similarity = 1 - distance
  133. if similarity >= similarityThreshold {
  134. similarGroup.append(imageFeatures[j].asset)
  135. processedIndices.insert(j)
  136. }
  137. // 更新比较进度
  138. processedComparisons += 1
  139. let compareProgress = Float(processedComparisons) / Float(totalComparisons)
  140. progressHandler("正在比较相似度...", 0.6 + compareProgress * 0.4)
  141. } catch {
  142. print("相似度计算失败: \(error)")
  143. }
  144. }
  145. if similarGroup.count > 1 {
  146. similarGroups.append(similarGroup)
  147. }
  148. }
  149. // 按照照片数量降序排序
  150. similarGroups.sort { $0.count > $1.count }
  151. DispatchQueue.main.async {
  152. completion(similarGroups)
  153. }
  154. }
  155. } else {
  156. // Fallback on earlier versions
  157. }
  158. }
  159. // 按地点分类
  160. private func classifyByLocation(assets: PHFetchResult<PHAsset>,
  161. completion: @escaping ([String: [PHAsset]]) -> Void) {
  162. var locationGroups: [String: [PHAsset]] = [:]
  163. let group = DispatchGroup()
  164. let geocodeQueue = DispatchQueue(label: "com.app.geocoding")
  165. let semaphore = DispatchSemaphore(value: 10) // 限制并发请求数
  166. assets.enumerateObjects { asset, _, _ in
  167. if let location = asset.location {
  168. group.enter()
  169. semaphore.wait()
  170. geocodeQueue.async {
  171. let geocoder = CLGeocoder()
  172. geocoder.reverseGeocodeLocation(location) { placemarks, error in
  173. defer {
  174. semaphore.signal()
  175. group.leave()
  176. }
  177. if let placemark = placemarks?.first {
  178. let locationName = self.formatLocationName(placemark)
  179. DispatchQueue.main.async {
  180. if locationGroups[locationName] == nil {
  181. locationGroups[locationName] = []
  182. }
  183. locationGroups[locationName]?.append(asset)
  184. }
  185. }
  186. }
  187. }
  188. }
  189. }
  190. // 等待所有地理编码完成后回调
  191. group.notify(queue: .main) {
  192. completion(locationGroups)
  193. }
  194. }
  195. // 格式化地点名称(只返回城市名)
  196. private func formatLocationName(_ placemark: CLPlacemark) -> String {
  197. if let city = placemark.locality {
  198. return city
  199. }
  200. return "其他"
  201. }
  202. // 按人物分类
  203. private func classifyByPeople(assets: PHFetchResult<PHAsset>,
  204. completion: @escaping ([String: [PHAsset]]) -> Void) {
  205. var peopleGroups: [String: [PHAsset]] = [:]
  206. let group = DispatchGroup()
  207. // 创建一个数组来存储检测到人脸的照片
  208. var facesArray: [PHAsset] = []
  209. // 遍历所有照片
  210. assets.enumerateObjects { asset, _, _ in
  211. group.enter()
  212. // 获取照片的缩略图进行人脸检测
  213. let options = PHImageRequestOptions()
  214. options.isSynchronous = false
  215. options.deliveryMode = .fastFormat
  216. PHImageManager.default().requestImage(
  217. for: asset,
  218. targetSize: CGSize(width: 500, height: 500), // 使用较小的尺寸提高性能
  219. contentMode: .aspectFit,
  220. options: options
  221. ) { image, _ in
  222. guard let image = image else {
  223. group.leave()
  224. return
  225. }
  226. // 使用 Vision 框架检测人脸
  227. guard let ciImage = CIImage(image: image) else {
  228. group.leave()
  229. return
  230. }
  231. let request = VNDetectFaceRectanglesRequest()
  232. let handler = VNImageRequestHandler(ciImage: ciImage)
  233. do {
  234. try handler.perform([request])
  235. if let results = request.results, !results.isEmpty {
  236. // 检测到人脸,添加到数组
  237. DispatchQueue.main.async {
  238. facesArray.append(asset)
  239. }
  240. }
  241. } catch {
  242. print("人脸检测失败: \(error)")
  243. }
  244. group.leave()
  245. }
  246. }
  247. // 等待所有检测完成后更新结果
  248. group.notify(queue: .main) {
  249. if !facesArray.isEmpty {
  250. peopleGroups["包含人脸的照片"] = facesArray
  251. }
  252. completion(peopleGroups)
  253. }
  254. }
  255. // 识别截图
  256. private func fetchScreenshots(from assets: PHFetchResult<PHAsset>,
  257. completion: @escaping ([PHAsset]) -> Void) {
  258. var screenshots: [PHAsset] = []
  259. // 获取系统的截图智能相册
  260. let screenshotAlbums = PHAssetCollection.fetchAssetCollections(
  261. with: .smartAlbum,
  262. subtype: .smartAlbumScreenshots,
  263. options: nil
  264. )
  265. // 从截图相册中获取所有截图
  266. screenshotAlbums.enumerateObjects { collection, _, _ in
  267. let fetchOptions = PHFetchOptions()
  268. let screenshotAssets = PHAsset.fetchAssets(in: collection, options: fetchOptions)
  269. screenshotAssets.enumerateObjects { asset, _, _ in
  270. screenshots.append(asset)
  271. }
  272. }
  273. completion(screenshots)
  274. }
  275. }