Caffe Source Code Analysis -- The Blob Class


Reposted from: http://blog.csdn.net/lingerlanlan/article/details/24379689

  • Data members

protected:
  shared_ptr<SyncedMemory> data_;        // the data itself; a smart pointer to SyncedMemory
  shared_ptr<SyncedMemory> diff_;        // the parameter update (gradient)
  shared_ptr<SyncedMemory> shape_data_;  // the shape, held in SyncedMemory
  vector<int> shape_;                    // the shape, held as a vector
  int count_;                            // number of elements currently in use
  int capacity_;                         // number of elements allocated
  • Constructors
Blob(): data_(), diff_(), count_(0), capacity_(0) {}
explicit Blob(const int num, const int channels, const int height, const int width);

template <typename Dtype>
Blob<Dtype>::Blob(const vector<int>& shape)  // this is the one normally used
    : capacity_(0) {
  Reshape(shape);
}
template <typename Dtype>
void Blob<Dtype>::Reshape(const vector<int>& shape) {
  CHECK_LE(shape.size(), kMaxBlobAxes);
  count_ = 1;
  shape_.resize(shape.size());
  if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
    shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
  }
  int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
  for (int i = 0; i < shape.size(); ++i) {
    CHECK_GE(shape[i], 0);
    CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
    count_ *= shape[i];
    shape_[i] = shape[i];
    shape_data[i] = shape[i];
  }
  if (count_ > capacity_) {
    capacity_ = count_;
    data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
    diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
  }
}
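Note that Reshape only reallocates when the new element count exceeds capacity_; shrinking just updates the bookkeeping. A minimal usage sketch (assuming a Caffe build with its headers on the include path):

#include <vector>
#include "caffe/blob.hpp"

int main() {
  caffe::Blob<float> blob(std::vector<int>{2, 3, 4, 5});  // count_ = 120, capacity_ = 120
  // Shrinking: the new count (24) fits in capacity_ (120), so no reallocation.
  blob.Reshape(std::vector<int>{2, 3, 4});
  // Growing past capacity_: data_ and diff_ are reset to fresh SyncedMemory,
  // so any previous contents are discarded.
  blob.Reshape(std::vector<int>{4, 3, 4, 5});
  return 0;
}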
void Reshape(const BlobShape& shape);
void ReshapeLike(const Blob& other);
  • Other functions
inline const vector<int>& shape() const { return shape_; }
inline int shape(int index) const {
  return shape_[CanonicalAxisIndex(index)];
}
inline int num_axes() const { return shape_.size(); }
inline int count() const { return count_; }
inline int count(int start_axis, int end_axis) const;  // number of elements spanning axes [start_axis, end_axis); the interval is half-open
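For example, with shape {2, 3, 4, 5}, count(1, 3) multiplies only the dims in [1, 3). A small sketch (assuming Caffe is available):

#include <cassert>
#include <vector>
#include "caffe/blob.hpp"

int main() {
  caffe::Blob<float> blob(std::vector<int>{2, 3, 4, 5});
  assert(blob.count() == 2 * 3 * 4 * 5);  // 120
  assert(blob.count(1, 3) == 3 * 4);      // axes [1, 3) only
  assert(blob.count(2) == 4 * 5);         // overload: axes [2, num_axes())
  return 0;
}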
// These four functions are deprecated; use shape(i) instead
inline int num() const;
inline int channels() const;
inline int height() const;
inline int width() const;
// return the flat offset of an element
inline int offset(const int n, const int c = 0, const int h = 0, const int w = 0) const;
inline int offset(const vector<int>& indices) const;  // prefer this one
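For a 4-axis (num, channels, height, width) blob, offset() computes the row-major flat index. A standalone sketch of the same arithmetic (no Caffe needed to compile this one):

#include <cassert>

// Mirrors the index arithmetic in blob.hpp:
// ((n * channels + c) * height + h) * width + w
int offset_nchw(int n, int c, int h, int w,
                int channels, int height, int width) {
  return ((n * channels + c) * height + h) * width + w;
}

int main() {
  // Shape {2, 3, 4, 5}: element (1, 2, 3, 4) is the last of 120 elements,
  // so its flat offset is 119.
  assert(offset_nchw(1, 2, 3, 4, 3, 4, 5) == 119);
  return 0;
}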
// copy data from source
template <typename Dtype>
void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
  if (source.count() != count_ || source.shape() != shape_) {
    if (reshape) {
      ReshapeLike(source);
    } else {
      LOG(FATAL) << "Trying to copy blobs of different sizes.";
    }
  }
  switch (Caffe::mode()) {
  case Caffe::GPU:
    if (copy_diff) {  // if copy_diff is true, copy the diff; otherwise copy the data
      caffe_copy(count_, source.gpu_diff(),
          static_cast<Dtype*>(diff_->mutable_gpu_data()));
    } else {
      caffe_copy(count_, source.gpu_data(),
          static_cast<Dtype*>(data_->mutable_gpu_data()));
    }
    break;
  case Caffe::CPU:
    if (copy_diff) {
      caffe_copy(count_, source.cpu_diff(),
          static_cast<Dtype*>(diff_->mutable_cpu_data()));
    } else {
      caffe_copy(count_, source.cpu_data(),
          static_cast<Dtype*>(data_->mutable_cpu_data()));
    }
    break;
  default:
    LOG(FATAL) << "Unknown caffe mode.";
  }
}
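A usage sketch for CopyFrom (assuming Caffe is available); note that copying into a blob of a different shape is fatal unless reshape is true:

#include <vector>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"

int main() {
  caffe::Caffe::set_mode(caffe::Caffe::CPU);
  caffe::Blob<float> src(std::vector<int>{1, 3, 2, 2});
  caffe::Blob<float> dst;          // default-constructed, count 0
  dst.CopyFrom(src, false, true);  // copy data; reshape=true makes dst match src first
  dst.CopyFrom(src, true, false);  // shapes now match, so copying diff needs no reshape
  return 0;
}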
// serialize the blob into a BlobProto (double specialization)
template <>
void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
  proto->clear_shape();
  for (int i = 0; i < shape_.size(); ++i) {
    proto->mutable_shape()->add_dim(shape_[i]);
  }
  proto->clear_double_data();
  proto->clear_double_diff();
  const double* data_vec = cpu_data();
  for (int i = 0; i < count_; ++i) {
    proto->add_double_data(data_vec[i]);
  }
  if (write_diff) {
    const double* diff_vec = cpu_diff();
    for (int i = 0; i < count_; ++i) {
      proto->add_double_diff(diff_vec[i]);
    }
  }
}
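A round-trip sketch (assuming Caffe is available): ToProto writes shape, data, and optionally diff into a BlobProto, and FromProto restores them:

#include <vector>
#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"

int main() {
  caffe::Blob<double> blob(std::vector<int>{2, 2});
  blob.mutable_cpu_data()[0] = 3.14;

  caffe::BlobProto proto;
  blob.ToProto(&proto, true);       // write_diff = true: also serialize diff_

  caffe::Blob<double> restored;
  restored.FromProto(proto, true);  // reshape = true: rebuild shape, then copy values
  return 0;
}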
// element access
inline Dtype data_at(const int n, const int c, const int h, const int w) const;
inline Dtype diff_at(const int n, const int c, const int h, const int w) const;
inline Dtype data_at(const vector<int>& index) const;
inline Dtype diff_at(const vector<int>& index) const;
inline const shared_ptr<SyncedMemory>& data() const;  // return the data_ smart pointer
inline const shared_ptr<SyncedMemory>& diff() const;  // return the diff_ smart pointer
const int* Blob<Dtype>::gpu_shape() const;   // return (const int*)shape_data_->gpu_data()
const Dtype* Blob<Dtype>::cpu_data() const;  // return (const Dtype*)data_->cpu_data()
const Dtype* Blob<Dtype>::gpu_data() const;  // return (const Dtype*)data_->gpu_data()
const Dtype* Blob<Dtype>::cpu_diff() const;  // return (const Dtype*)diff_->cpu_data()
const Dtype* Blob<Dtype>::gpu_diff() const;  // return (const Dtype*)diff_->gpu_data()
Dtype* Blob<Dtype>::mutable_cpu_data();      // return static_cast<Dtype*>(data_->mutable_cpu_data())
Dtype* Blob<Dtype>::mutable_gpu_data();      // return static_cast<Dtype*>(data_->mutable_gpu_data())
Dtype* Blob<Dtype>::mutable_cpu_diff();      // return static_cast<Dtype*>(diff_->mutable_cpu_data())
Dtype* Blob<Dtype>::mutable_gpu_diff();      // return static_cast<Dtype*>(diff_->mutable_gpu_data())
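The const/mutable split is what drives CPU/GPU synchronization: a mutable_* call marks that device's copy as the freshest, so the next access from the other device triggers a copy. A CPU-only sketch (assuming Caffe is available):

#include <vector>
#include "caffe/blob.hpp"

int main() {
  caffe::Blob<float> blob(std::vector<int>{1, 1, 2, 2});
  float* out = blob.mutable_cpu_data();  // head moves to HEAD_AT_CPU
  for (int i = 0; i < blob.count(); ++i) {
    out[i] = static_cast<float>(i);
  }
  const float* in = blob.cpu_data();  // read-only view; head state unchanged
  (void)in;  // a later gpu_data() call would copy these values to the GPU
  return 0;
}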
// share other's data_ / diff_ with this blob (no copy)
void Blob<Dtype>::ShareData(const Blob& other);
void Blob<Dtype>::ShareDiff(const Blob& other);
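A sketch of sharing (assuming Caffe; the two blobs must have equal counts): after ShareData, both blobs point at the same SyncedMemory, so writes through either are visible in both:

#include <vector>
#include "caffe/blob.hpp"

int main() {
  caffe::Blob<float> a(std::vector<int>{2, 2});
  caffe::Blob<float> b(std::vector<int>{2, 2});
  b.ShareData(a);                  // b's data_ now aliases a's data_
  a.mutable_cpu_data()[0] = 7.0f;  // b.cpu_data()[0] also reads 7.0f
  return 0;
}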
// update the weights: data_ = data_ - diff_
template <typename Dtype>
void Blob<Dtype>::Update() {
  // We will perform update based on where the data is located.
  switch (data_->head()) {
  case SyncedMemory::HEAD_AT_CPU:
    // perform computation on CPU
    caffe_axpy<Dtype>(count_, Dtype(-1),
        static_cast<const Dtype*>(diff_->cpu_data()),
        static_cast<Dtype*>(data_->mutable_cpu_data()));
    break;
  case SyncedMemory::HEAD_AT_GPU:
  case SyncedMemory::SYNCED:
#ifndef CPU_ONLY
    // perform computation on GPU
    caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
        static_cast<const Dtype*>(diff_->gpu_data()),
        static_cast<Dtype*>(data_->mutable_gpu_data()));
#else
    NO_GPU;
#endif
    break;
  default:
    LOG(FATAL) << "Syncedmem not initialized.";
  }
}
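Update() is an axpy with alpha = -1, i.e. data_ -= diff_; when diff_ holds lr * gradient, this is exactly one vanilla gradient-descent step. A minimal sketch (assuming Caffe is available):

#include <vector>
#include "caffe/blob.hpp"

int main() {
  caffe::Blob<float> w(std::vector<int>{4});
  w.mutable_cpu_data()[0] = 1.0f;
  w.mutable_cpu_diff()[0] = 0.1f;  // pretend this is lr * dL/dw
  w.Update();                      // w.cpu_data()[0] is now 0.9f
  return 0;
}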
Dtype asum_data() const;   // L1 norm of data (sum of absolute values)
Dtype asum_diff() const;   // L1 norm of diff
Dtype sumsq_data() const;  // sum of squares of data (the squared L2 norm)
Dtype sumsq_diff() const;  // sum of squares of diff

// scale data and diff by a constant factor
void scale_data(Dtype scale_factor);
void scale_diff(Dtype scale_factor);
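For data = {1, -2, 3}: asum_data() is 6 (sum of absolute values), while sumsq_data() is 14, the sum of squares; take the square root yourself if you want the actual L2 norm. A sketch (assuming Caffe is available):

#include <vector>
#include "caffe/blob.hpp"

int main() {
  caffe::Blob<float> blob(std::vector<int>{3});
  float* d = blob.mutable_cpu_data();
  d[0] = 1.0f; d[1] = -2.0f; d[2] = 3.0f;
  // blob.asum_data()  == 6.0f   (|1| + |-2| + |3|)
  // blob.sumsq_data() == 14.0f  (1 + 4 + 9)
  blob.scale_data(0.5f);  // every data element multiplied by 0.5
  return 0;
}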
bool ShapeEquals(const BlobProto& other);  // check whether every dimension matches