// ClassifyPhoto.swift
import CoreLocation
import Photos
import Vision
  3. class ClassifyPhoto {
  4. struct ClassifiedPhotos {
  5. var screenshots: [PHAsset] = []
  6. var locations: [String: [PHAsset]] = [:] // 按地点分组
  7. var people: [String: [PHAsset]] = [:] // 按人物分组
  8. var similarPhotos: [[PHAsset]] = [] // 存储相似照片组
  9. }
  10. func classifyPhotos(
  11. assets: PHFetchResult<PHAsset>,
  12. progressHandler: @escaping (String, Float) -> Void,
  13. completion: @escaping (ClassifiedPhotos) -> Void
  14. ) {
  15. // 在后台队列处理
  16. DispatchQueue.global(qos: .userInitiated).async {
  17. var result = ClassifiedPhotos()
  18. let group = DispatchGroup()
  19. // 开始处理
  20. DispatchQueue.main.async {
  21. progressHandler("正在加载照片...", 0.0)
  22. }
  23. // 1. 检测截图 (占总进度的 20%)
  24. group.enter()
  25. self.fetchScreenshots(from: assets) { screenshots in
  26. result.screenshots = screenshots
  27. DispatchQueue.main.async {
  28. progressHandler("正在检测截图...", 0.2)
  29. }
  30. group.leave()
  31. }
  32. // 2. 检测相似照片 (占总进度的 80%)
  33. group.enter()
  34. self.detectSimilarPhotos(
  35. assets: assets,
  36. progressHandler: { stage, progress in
  37. // 将相似照片检测的进度映射到 20%-100% 的范围
  38. let mappedProgress = 0.2 + (progress * 0.6)
  39. DispatchQueue.main.async {
  40. progressHandler(stage, mappedProgress)
  41. }
  42. }
  43. ) { similarPhotos in
  44. result.similarPhotos = similarPhotos
  45. group.leave()
  46. }
  47. // 3. 按地点分类 (占总进度的 20%)
  48. group.enter()
  49. self.classifyByLocation(assets: assets) { locationGroups in
  50. result.locations = locationGroups
  51. DispatchQueue.main.async {
  52. progressHandler("正在按地点分类...", 0.8)
  53. }
  54. group.leave()
  55. }
  56. // 4. 按人物分类 (占总进度的 20%)
  57. group.enter()
  58. self.classifyByPeople(assets: assets) { peopleGroups in
  59. result.people = peopleGroups
  60. DispatchQueue.main.async {
  61. progressHandler("正在按人物分类...", 1.0)
  62. }
  63. group.leave()
  64. }
  65. // 等待所有处理完成
  66. group.notify(queue: .main) {
  67. progressHandler("分类完成", 1.0)
  68. completion(result)
  69. }
  70. }
  71. }
  72. private func detectSimilarPhotos(
  73. assets: PHFetchResult<PHAsset>,
  74. progressHandler: @escaping (String, Float) -> Void,
  75. completion: @escaping ([[PHAsset]]) -> Void
  76. ) {
  77. var similarGroups: [[PHAsset]] = []
  78. let group = DispatchGroup()
  79. if #available(iOS 13.0, *) {
  80. var imageFeatures: [(asset: PHAsset, feature: VNFeaturePrintObservation)] = []
  81. // 创建处理队列
  82. let processingQueue = DispatchQueue(label: "com.app.similarPhotos", qos: .userInitiated)
  83. let semaphore = DispatchSemaphore(value: 5)
  84. // 1. 提取所有图片的特征
  85. let totalAssets = assets.count
  86. var processedAssets = 0
  87. progressHandler("正在加载照片...", 0.0)
  88. for i in 0..<assets.count {
  89. let asset = assets[i]
  90. group.enter()
  91. semaphore.wait()
  92. let options = PHImageRequestOptions()
  93. options.deliveryMode = .highQualityFormat
  94. options.isSynchronous = false
  95. options.resizeMode = .exact
  96. PHImageManager.default().requestImage(
  97. for: asset,
  98. targetSize: CGSize(width: 448, height: 448),
  99. contentMode: .aspectFit,
  100. options: options
  101. ) { image, _ in
  102. defer {
  103. semaphore.signal()
  104. }
  105. guard let image = image,
  106. let cgImage = image.cgImage else {
  107. group.leave()
  108. return
  109. }
  110. processingQueue.async {
  111. do {
  112. let requestHandler = VNImageRequestHandler(cgImage: cgImage, options: [:])
  113. let request = VNGenerateImageFeaturePrintRequest()
  114. try requestHandler.perform([request])
  115. if let result = request.results?.first as? VNFeaturePrintObservation {
  116. imageFeatures.append((asset, result))
  117. // 更新特征提取进度
  118. processedAssets += 1
  119. let progress = Float(processedAssets) / Float(totalAssets)
  120. progressHandler("正在提取特征...", progress * 0.6)
  121. }
  122. } catch {
  123. print("特征提取失败: \(error)")
  124. }
  125. group.leave()
  126. }
  127. }
  128. }
  129. // 2. 比较特征相似度并分组
  130. group.notify(queue: processingQueue) {
  131. progressHandler("正在比较相似度...", 0.6)
  132. // 近似度
  133. let similarityThreshold: Float = 0.7
  134. var processedComparisons = 0
  135. let totalComparisons = (imageFeatures.count * (imageFeatures.count - 1)) / 2
  136. var processedIndices = Set<Int>()
  137. for i in 0..<imageFeatures.count {
  138. if processedIndices.contains(i) { continue }
  139. var similarGroup: [PHAsset] = [imageFeatures[i].asset]
  140. processedIndices.insert(i)
  141. for j in (i + 1)..<imageFeatures.count {
  142. if processedIndices.contains(j) { continue }
  143. do {
  144. var distance: Float = 0
  145. try imageFeatures[i].feature.computeDistance(&distance, to: imageFeatures[j].feature)
  146. let similarity = 1 - distance
  147. if similarity >= similarityThreshold {
  148. similarGroup.append(imageFeatures[j].asset)
  149. processedIndices.insert(j)
  150. }
  151. // 更新比较进度
  152. processedComparisons += 1
  153. let compareProgress = Float(processedComparisons) / Float(totalComparisons)
  154. progressHandler("正在比较相似度...", 0.6 + compareProgress * 0.4)
  155. } catch {
  156. print("相似度计算失败: \(error)")
  157. }
  158. }
  159. if similarGroup.count > 1 {
  160. similarGroups.append(similarGroup)
  161. }
  162. }
  163. // 按照照片数量降序排序
  164. similarGroups.sort { $0.count > $1.count }
  165. DispatchQueue.main.async {
  166. completion(similarGroups)
  167. }
  168. }
  169. }
  170. }
  171. // 按地点分类
  172. private func classifyByLocation(assets: PHFetchResult<PHAsset>,
  173. completion: @escaping ([String: [PHAsset]]) -> Void) {
  174. var locationGroups: [String: [PHAsset]] = [:]
  175. let group = DispatchGroup()
  176. let geocodeQueue = DispatchQueue(label: "com.app.geocoding")
  177. let semaphore = DispatchSemaphore(value: 10) // 限制并发请求数
  178. assets.enumerateObjects { asset, _, _ in
  179. if let location = asset.location {
  180. group.enter()
  181. semaphore.wait()
  182. geocodeQueue.async {
  183. let geocoder = CLGeocoder()
  184. geocoder.reverseGeocodeLocation(location) { placemarks, error in
  185. defer {
  186. semaphore.signal()
  187. group.leave()
  188. }
  189. if let placemark = placemarks?.first {
  190. let locationName = self.formatLocationName(placemark)
  191. DispatchQueue.main.async {
  192. if locationGroups[locationName] == nil {
  193. locationGroups[locationName] = []
  194. }
  195. locationGroups[locationName]?.append(asset)
  196. }
  197. }
  198. }
  199. }
  200. }
  201. }
  202. // 等待所有地理编码完成后回调
  203. group.notify(queue: .main) {
  204. completion(locationGroups)
  205. }
  206. }
  207. // 格式化地点名称(只返回城市名)
  208. private func formatLocationName(_ placemark: CLPlacemark) -> String {
  209. if let city = placemark.locality {
  210. return city
  211. }
  212. return "其他"
  213. }
  214. // 按人物分类
  215. private func classifyByPeople(assets: PHFetchResult<PHAsset>,
  216. completion: @escaping ([String: [PHAsset]]) -> Void) {
  217. var peopleGroups: [String: [PHAsset]] = [:]
  218. let group = DispatchGroup()
  219. // 创建一个数组来存储检测到人脸的照片
  220. var facesArray: [PHAsset] = []
  221. // 遍历所有照片
  222. assets.enumerateObjects { asset, _, _ in
  223. group.enter()
  224. // 获取照片的缩略图进行人脸检测
  225. let options = PHImageRequestOptions()
  226. options.isSynchronous = false
  227. options.deliveryMode = .fastFormat
  228. PHImageManager.default().requestImage(
  229. for: asset,
  230. targetSize: CGSize(width: 500, height: 500), // 使用较小的尺寸提高性能
  231. contentMode: .aspectFit,
  232. options: options
  233. ) { image, _ in
  234. guard let image = image else {
  235. group.leave()
  236. return
  237. }
  238. // 使用 Vision 框架检测人脸
  239. guard let ciImage = CIImage(image: image) else {
  240. group.leave()
  241. return
  242. }
  243. let request = VNDetectFaceRectanglesRequest()
  244. let handler = VNImageRequestHandler(ciImage: ciImage)
  245. do {
  246. try handler.perform([request])
  247. if let results = request.results, !results.isEmpty {
  248. // 检测到人脸,添加到数组
  249. DispatchQueue.main.async {
  250. facesArray.append(asset)
  251. }
  252. }
  253. } catch {
  254. print("人脸检测失败: \(error)")
  255. }
  256. group.leave()
  257. }
  258. }
  259. // 等待所有检测完成后更新结果
  260. group.notify(queue: .main) {
  261. if !facesArray.isEmpty {
  262. peopleGroups["包含人脸的照片"] = facesArray
  263. }
  264. completion(peopleGroups)
  265. }
  266. }
  267. // 识别截图
  268. private func fetchScreenshots(from assets: PHFetchResult<PHAsset>,
  269. completion: @escaping ([PHAsset]) -> Void) {
  270. var screenshots: [PHAsset] = []
  271. // 获取系统的截图智能相册
  272. let screenshotAlbums = PHAssetCollection.fetchAssetCollections(
  273. with: .smartAlbum,
  274. subtype: .smartAlbumScreenshots,
  275. options: nil
  276. )
  277. // 从截图相册中获取所有截图
  278. screenshotAlbums.enumerateObjects { collection, _, _ in
  279. let fetchOptions = PHFetchOptions()
  280. let screenshotAssets = PHAsset.fetchAssets(in: collection, options: fetchOptions)
  281. screenshotAssets.enumerateObjects { asset, _, _ in
  282. screenshots.append(asset)
  283. }
  284. }
  285. completion(screenshots)
  286. }
  287. }