tiny-dnn源码Edge类
来源:互联网 发布:淘宝二级页面是什么 编辑:程序博客网 时间:2024/06/08 17:58
class edge { public: edge(node* prev, const shape3d& shape, vector_type vtype) : shape_(shape), vtype_(vtype), data_({vec_t(shape.size())}), grad_({ vec_t(shape.size()) }), prev_(prev) { } void merge_grads(vec_t *dst) { dst->resize(grad_[0].size()); std::fill(dst->begin(), dst->end(), static_cast<float_t>(0)); // @todo consider adding parallelism for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) { vectorize::reduce<float_t>(&grad_[sample][0], dst->size(), &(*dst)[0]); } } void clear_grads() { for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) { std::fill(grad_[sample].begin(), grad_[sample].end(), (float_t)0); } } tensor_t* get_data() { return &data_; } const tensor_t* get_data() const { return &data_; } tensor_t* get_gradient() { return &grad_; } const tensor_t* get_gradient() const { return &grad_; } const std::vector<node*>& next() const { return next_; } node* prev() { return prev_; } const node* prev() const { return prev_; } const shape3d& shape() const { return shape_; } vector_type vtype() const { return vtype_; } void add_next_node(node* next) { next_.push_back(next); } private: shape3d shape_; vector_type vtype_; tensor_t data_; tensor_t grad_; node* prev_; // previous node, "producer" of this tensor std::vector<node*> next_; // next nodes, "consumers" of this tensor};
class node : public std::enable_shared_from_this<node> {public: node(cnn_size_t in_size, cnn_size_t out_size) : prev_(in_size), next_(out_size) {} virtual ~node() {} const std::vector<edgeptr_t>& prev() const { return prev_; } const std::vector<edgeptr_t>& next() const { return next_; } cnn_size_t prev_port(const edge& e) const { auto it = std::find_if(prev_.begin(), prev_.end(), [&](edgeptr_t ep) { return ep.get() == &e; }); return (cnn_size_t)std::distance(prev_.begin(), it); } cnn_size_t next_port(const edge& e) const { auto it = std::find_if(next_.begin(), next_.end(), [&](edgeptr_t ep) { return ep.get() == &e; }); return (cnn_size_t)std::distance(next_.begin(), it); } std::vector<node*> prev_nodes() const; // @todo refactor and remove this method std::vector<node*> next_nodes() const; // @todo refactor and remove this method protected: node() = delete; friend void connect(layerptr_t head, layerptr_t tail, cnn_size_t head_index, cnn_size_t tail_index); mutable std::vector<edgeptr_t> prev_; mutable std::vector<edgeptr_t> next_;};
阅读全文
0 0
- tiny-dnn源码Edge类
- 【卷积神经网络】tiny-dnn环境配置
- 【卷积神经网络】tiny-dnn网络参数
- 【卷积神经网络】tiny-dnn环境配置
- tiny-dnn import caffe's model
- 【卷积神经网络】tiny-dnn网络参数
- DNN源码安装方法
- 深度学习开源库tiny-dnn的使用(MNIST)
- 深度学习开源库tiny-dnn的使用(MNIST)
- 使用DotNetNuke(DNN)源码安装
- DNN
- EDGE
- Edge
- EDGE
- Edge
- EDGE
- EDGE
- Edge
- jsp基础
- Java&Android零碎的知识点2
- Btree索引详解
- Java应用Tomcat执行过程之性能调优
- HashMap,LinkedHashMap,TreeMap对比
- tiny-dnn源码Edge类
- Python图表绘制:matplotlib绘图库入门
- codeforces 454A Little Pony and Crystal Mine
- SpringBoot整合MyBatis
- Opengl RC(Render context,渲染上下文)与像素格式(转)
- Class字节码指令解释执行
- mysql常见面试题汇总
- 进程间通信(三)信号
- 位运算(ctrl c,v的)