Using CNNs (Convolutional Neural Networks) on iOS

Source: Internet | Editor: 程序博客网 | Date: 2024/06/04 18:40

iOS 11 introduced Core ML, along with the Vision framework built on top of it, paving the way for running CNNs (convolutional neural networks) on iOS devices.

Loading a Core ML model into the app

Integrating a Core ML model into your app is simple: drag the model file (*.mlmodel) into the Xcode project. Xcode then displays the model's description.
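Xcode also automatically generates a class for the model file that can be used directly from code. For this CNN, which predicts a person's age, the auto-generated Swift code is: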

//
// AgeNet.swift
//
// This file was automatically generated and should not be edited.
//

import CoreML

/// Model Prediction Input Type
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class AgeNetInput : MLFeatureProvider {

    /// An image with a face. as color (kCVPixelFormatType_32BGRA) image buffer, 227 pixels wide by 227 pixels high
    var data: CVPixelBuffer

    var featureNames: Set<String> {
        get {
            return ["data"]
        }
    }

    func featureValue(for featureName: String) -> MLFeatureValue? {
        if (featureName == "data") {
            return MLFeatureValue(pixelBuffer: data)
        }
        return nil
    }

    init(data: CVPixelBuffer) {
        self.data = data
    }
}

/// Model Prediction Output Type
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class AgeNetOutput : MLFeatureProvider {

    /// The probabilities for each age, for the given input. as dictionary of strings to doubles
    let prob: [String : Double]

    /// The most likely age, for the given input. as string value
    let classLabel: String

    var featureNames: Set<String> {
        get {
            return ["prob", "classLabel"]
        }
    }

    func featureValue(for featureName: String) -> MLFeatureValue? {
        if (featureName == "prob") {
            return try! MLFeatureValue(dictionary: prob as [NSObject : NSNumber])
        }
        if (featureName == "classLabel") {
            return MLFeatureValue(string: classLabel)
        }
        return nil
    }

    init(prob: [String : Double], classLabel: String) {
        self.prob = prob
        self.classLabel = classLabel
    }
}

/// Class for model loading and prediction
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class AgeNet {
    var model: MLModel

    /**
        Construct a model with explicit path to mlmodel file
        - parameters:
           - url: the file url of the model
           - throws: an NSError object that describes the problem
    */
    init(contentsOf url: URL) throws {
        self.model = try MLModel(contentsOf: url)
    }

    /// Construct a model that automatically loads the model from the app's bundle
    convenience init() {
        let bundle = Bundle(for: AgeNet.self)
        let assetPath = bundle.url(forResource: "AgeNet", withExtension:"mlmodelc")
        try! self.init(contentsOf: assetPath!)
    }

    /**
        Make a prediction using the structured interface
        - parameters:
           - input: the input to the prediction as AgeNetInput
        - throws: an NSError object that describes the problem
        - returns: the result of the prediction as AgeNetOutput
    */
    func prediction(input: AgeNetInput) throws -> AgeNetOutput {
        let outFeatures = try model.prediction(from: input)
        let result = AgeNetOutput(prob: outFeatures.featureValue(for: "prob")!.dictionaryValue as! [String : Double], classLabel: outFeatures.featureValue(for: "classLabel")!.stringValue)
        return result
    }

    /**
        Make a prediction using the convenience interface
        - parameters:
            - data: An image with a face. as color (kCVPixelFormatType_32BGRA) image buffer, 227 pixels wide by 227 pixels high
        - throws: an NSError object that describes the problem
        - returns: the result of the prediction as AgeNetOutput
    */
    func prediction(data: CVPixelBuffer) throws -> AgeNetOutput {
        let input_ = AgeNetInput(data: data)
        return try self.prediction(input: input_)
    }
}
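The generated AgeNetOutput exposes the full distribution in prob ([String : Double]), while classLabel holds only the single most likely age. If you want to show, say, the three most likely age labels, the dictionary can be ranked. Below is a minimal pure-Swift sketch; the function name topPredictions is hypothetical and is not part of the generated code.

```swift
/// Hypothetical helper: returns the `k` most probable labels from a
/// probability dictionary such as `AgeNetOutput.prob`, highest first.
/// Ties are broken alphabetically so the order is deterministic.
func topPredictions(_ prob: [String: Double], k: Int) -> [(label: String, probability: Double)] {
    // Sort by descending probability, then by label for equal probabilities
    let sorted = prob.sorted { lhs, rhs in
        lhs.value != rhs.value ? lhs.value > rhs.value : lhs.key < rhs.key
    }
    // Guard against a negative k, which prefix(_:) would reject
    return sorted.prefix(max(k, 0)).map { (label: $0.key, probability: $0.value) }
}
```

In the app this would be called on the result of ageModel.prediction(data:), for example topPredictions(output.prob, k: 3), where output is a hypothetical AgeNetOutput value.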

Loading the CNN and creating an image analysis request (VNCoreMLRequest)

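// Requires `import Vision` at the top of the file.
// ageRequest (VNCoreMLRequest?) and mInfo (assumed to be a UILabel) are properties of the view controller.
let ageModel = AgeNet()

func setupVision() {
    // Wrap the Core ML model so that Vision can run it
    guard let vnAgeModel = try? VNCoreMLModel(for: ageModel.model) else {
        NSLog("Load age model fail")
        return
    }
    ageRequest = VNCoreMLRequest(model: vnAgeModel, completionHandler: { (request: VNRequest, error: Error?) in
        // A classifier model produces VNClassificationObservation results;
        // the first one is treated as the most likely age
        if let observations = request.results as? [VNClassificationObservation],
           let best = observations.first, best.confidence > 0.5 {
            // UI must be updated on the main thread
            DispatchQueue.main.async {
                self.mInfo.text = "Your age is " + best.identifier + "/" + String(best.confidence)
            }
        }
    })
    // Scale the whole frame, preserving aspect ratio, to fit the model's 227x227 input
    ageRequest?.imageCropAndScaleOption = .scaleFit
}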
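The .scaleFit option asks Vision to scale the frame uniformly, preserving its aspect ratio, so that the entire frame fits inside the model's 227x227 input (unlike .centerCrop, which crops, or .scaleFill, which stretches). The geometry can be sketched in plain Swift for intuition. This is an approximation: centering the scaled image and padding the leftover area are assumptions here, not documented Vision internals, and the Rect type and aspectFitRect function are hypothetical.

```swift
/// Hypothetical plain-Swift rectangle, so the sketch does not depend on CoreGraphics.
struct Rect {
    var x: Double, y: Double, width: Double, height: Double
}

/// Aspect-fit placement: scale a srcW x srcH image by a single factor so that
/// it fits entirely inside a dstW x dstH input, then center it.
/// Any area the scaled image does not cover is padding (letterboxing).
func aspectFitRect(srcW: Double, srcH: Double, dstW: Double, dstH: Double) -> Rect {
    precondition(srcW > 0 && srcH > 0 && dstW > 0 && dstH > 0, "sizes must be positive")
    // The tighter of the two ratios decides the scale, so nothing is cropped
    let scale = min(dstW / srcW, dstH / srcH)
    let w = srcW * scale
    let h = srcH * scale
    return Rect(x: (dstW - w) / 2, y: (dstH - h) / 2, width: w, height: h)
}
```

For a 1280x720 camera frame the scale factor is 227/1280, so the frame becomes 227 x 127.6875 pixels with roughly 49.7 pixels of padding above and below; a face therefore covers fewer of the model's input pixels than it would in a tightly cropped image.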

Performing the analysis

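func predict(pixelBuffer: CVPixelBuffer) {
    // ageRequest is optional; genderRequest is assumed to be declared and set up the same way for a gender model
    guard let ageRequest = ageRequest, let genderRequest = genderRequest else { return }
    // A single handler can run several requests against the same image
    let handler = VNImageRequestHandler(cvPixelBuffer: pixelBuffer)
    try? handler.perform([ageRequest, genderRequest])
}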
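VNImageRequestHandler.perform(_:) is synchronous: it returns only after its requests have finished, so calling predict(pixelBuffer:) on the main thread blocks the UI for the duration of inference. A common remedy is to run the work on a private serial queue and hop back to the main queue for UI updates, which is what the completion handler in setupVision() already does for the label. The sketch below shows the queueing pattern with plain Dispatch; BackgroundAnalyzer is a hypothetical helper, and its analyze closure stands in for the Vision call.

```swift
import Dispatch

/// Hypothetical helper: runs a blocking `analyze` closure (standing in for
/// `handler.perform([...])`) on a private serial queue and delivers each
/// result on the queue passed to `submit` (DispatchQueue.main in an app).
final class BackgroundAnalyzer<Input, Output> {
    // DispatchQueue is serial unless created with the .concurrent attribute
    private let workQueue = DispatchQueue(label: "analysis.serial")
    private let analyze: (Input) -> Output

    init(analyze: @escaping (Input) -> Output) {
        self.analyze = analyze
    }

    func submit(_ input: Input, deliverOn queue: DispatchQueue, completion: @escaping (Output) -> Void) {
        workQueue.async {
            // Blocking work runs here, off the caller's thread
            let result = self.analyze(input)
            // Hand the result to the delivery queue, e.g. to update a label
            queue.async { completion(result) }
        }
    }
}
```

Because the work queue is serial, frames are analyzed one at a time in submission order, and results arrive on the delivery queue in that same order.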

The complete code is available here:
https://github.com/volvet/GuessingAge

Reference

https://developer.apple.com/documentation/coreml/integrating_a_core_ml_model_into_your_app
https://github.com/SwiftBrain/awesome-CoreML-models
