tiny-dnn 源码：edge 类与 node 类 (tiny-dnn source code: the `edge` and `node` classes)

来源:互联网 发布:淘宝二级页面是什么 编辑:程序博客网 时间:2024/06/08 17:58
class edge { public:    edge(node* prev, const shape3d& shape, vector_type vtype)        : shape_(shape),          vtype_(vtype),          data_({vec_t(shape.size())}),          grad_({ vec_t(shape.size()) }),          prev_(prev)    {    }    void merge_grads(vec_t *dst) {        dst->resize(grad_[0].size());        std::fill(dst->begin(), dst->end(), static_cast<float_t>(0));        // @todo consider adding parallelism        for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {            vectorize::reduce<float_t>(&grad_[sample][0], dst->size(), &(*dst)[0]);        }    }    void clear_grads() {        for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {            std::fill(grad_[sample].begin(), grad_[sample].end(), (float_t)0);        }    }    tensor_t* get_data() {        return &data_;    }    const tensor_t* get_data() const {        return &data_;    }    tensor_t* get_gradient() {        return &grad_;    }    const tensor_t* get_gradient() const {        return &grad_;    }    const std::vector<node*>& next() const { return next_; }    node* prev() { return prev_; }    const node* prev() const { return prev_; }    const shape3d& shape() const { return shape_; }    vector_type vtype() const { return vtype_; }    void add_next_node(node* next) { next_.push_back(next); } private:    shape3d shape_;    vector_type vtype_;    tensor_t data_;    tensor_t grad_;    node* prev_;               // previous node, "producer" of this tensor    std::vector<node*> next_;  // next nodes, "consumers" of this tensor};
class node : public std::enable_shared_from_this<node> {public:    node(cnn_size_t in_size, cnn_size_t out_size)        : prev_(in_size), next_(out_size) {}    virtual ~node() {}    const std::vector<edgeptr_t>& prev() const { return prev_; }    const std::vector<edgeptr_t>& next() const { return next_; }    cnn_size_t prev_port(const edge& e) const {        auto it = std::find_if(prev_.begin(), prev_.end(),                               [&](edgeptr_t ep) { return ep.get() == &e; });        return (cnn_size_t)std::distance(prev_.begin(), it);    }    cnn_size_t next_port(const edge& e) const {        auto it = std::find_if(next_.begin(), next_.end(),                               [&](edgeptr_t ep) { return ep.get() == &e; });        return (cnn_size_t)std::distance(next_.begin(), it);    }    std::vector<node*> prev_nodes() const; // @todo refactor and remove this method    std::vector<node*> next_nodes() const; // @todo refactor and remove this method protected:    node() = delete;    friend void connect(layerptr_t head, layerptr_t tail,                        cnn_size_t head_index, cnn_size_t tail_index);    mutable std::vector<edgeptr_t> prev_;    mutable std::vector<edgeptr_t> next_;};
原创粉丝点击