神经网络识别手写优化(三)

来源:互联网 发布:ios虚拟定位软件下载 编辑:程序博客网 时间:2024/06/03 13:12

前言

本文是为了实现存储自己训练好的模型 结构和参数,以及加载训练好的模型进行预测。

代码

保存

    def save(self,filename):        """        模型保存        :param filename: 文件名         :return:         """        data ={ "sizes": self.sizes, #模型结构                "weights": [w.tolist() for w in self.weights], #tolist转换为列表类型                "biases": [b.tolist() for b in self.biases],                "cost": str(self.cost.__name__) #保存一下损失函数        }        f=open(filename,"w")        json.dump(data,f)        f.close()

加载

def load(filename):    """    加载模型    :param filename:     :return:     """    f=open(filename,"r")    data=json.load(f)    f.close()    cost=getattr(sys.modules[__name__],data["cost"]) #找对象    net=Network(data["sizes"],cost=cost)    net.weights=[np.array(w) for w in data["weights"]]    net.biases=[np.array(b) for b in data["biases"]]    return net
阅读全文
0 0
原创粉丝点击