XGBoost解析系列-数据加载
来源:互联网 发布:中国电信网络在线测速 编辑:程序博客网 时间:2024/06/05 07:58
- 前言
- XGBoost数据加载
- 1 DMatrixLoad主流程
- 2 解析器parser构建过程
- 3 DMatrix对象构建过程
0.前言
本文主要介绍XGBoost中数据加载过程,主要是DMatrix::Load
内容。
1. XGBoost数据加载
1.1 DMatrix::Load主流程
数据集加载语句为:
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(param.train_path, param.silent != 0, param.dsplit == 2));# 函数原型为:DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, const std::string& file_format)
可见,DMatrix::Load
返回为DMatrix
对象指针,参数传入为:1)文件URI
地址,2)silent
开关,为true打印统计信息,3)load_row_split
为分布式开关,为true,则对数据进行分片shard。4)file_format
为数据解析格式,默认参数为"auto"
,自动解析文件数据格式。基于mushroom.conf
配置, 提供训练数据地址:xgboost目录下demo/data/agaricus.txt.train
,数据为libsvm
格式。DMatrix::Load
主流程如下:
1. uri带有#符号,解析出cache_file
文件,考虑分布式模式的情况;
2. 分布式模式,获取当前主机排序partid
以及主机总数npart
,单机下partid=1;npart=1
;
3. 满足file_format == "auto" && npart == 1
,检测文件是否为二进制数据文件,检查开头魔方是否为SimpleCSRSource::kMagic
,若是则LoadBinary
直接初始化SimpleCSRSource
数据源对象source
。
4. 构建解析器parser
,基于工厂设计模式,基类dmlc::Parser<uint32_t>
调用静态方法Create()
, 尽管格式解析为”auto”,目前使用libsvm
格式解析。
5. 解析器parser
构建返回DMatrix
对象:dmat = DMatrix::Create(parser.get(), cache_file)
,核心过程,后续详解。
6. 根据参数slient打印相关统计信息,尝试读取.group
后缀的文件;.base_margin
后缀的boost初始设定值;.weight
样本的权重,用于代价敏感学习,不存在则跳过。
1.2 解析器parser构建过程
该小节先对步骤4进行详解,Parser<uint32_t>
继承于DataIter<RowBlock<uint32_t> >
,实际上内部数据以CSR格式
。基于静态方法Parser<uint32_t>::Create()
构建实例,该方法调用普通方法CreateParser_
方法。由于配置最后解析为libsvm
格式,使用ParserFactoryReg
工厂来找到对应的LibSVMParser
解析器对象。具体代码参考如下:
// 通过ptype=libsvm找到已注册的解析器工厂方法,Get()->Find()后面会看到const ParserFactoryReg<IndexType>* e = Registry<ParserFactoryReg<IndexType> >::Get()->Find(ptype);// 通过工厂方法生成解析器对象,最终调用的是CreateLibSVMParser<uint32_t>函数return (*e->body)(spec.uri, spec.args, part_index, num_parts);
理解上述的代码需要梳理以下的类定义与宏定义过程:
ParserFactoryReg解析器工厂类定义
// 工厂模式,通过宏定义来注册组件,继承FunctionRegEntryBase// 第1个类参数必须是本身类型,第2个类参数是工厂方法。// ParserFactoryReg会绑定工厂方法类型,通过工厂方法来组件对象,即Parser<IndexType>::Factor为函数类型// 后期会调用DMLC_REGISTRY_ENABLE实例生成静态单例工厂对象。宏定义实现可扩展的工厂模式+对象单例非常精彩template<typename IndexType>struct ParserFactoryReg : public FunctionRegEntryBase<ParserFactoryReg<IndexType>, typename Parser<IndexType>::Factory> {};
FunctionRegEntryBase工厂基类模板
// FunctionRegEntryBase方法注册类模板,需要注册项类型,方法类型template<typename EntryType, typename FunctionType>class FunctionRegEntryBase { public: std::string name; // 注册项名字 std::string description; // 注册项描述 std::vector<ParamFieldInfo> arguments; // factory function调用参数,ParamFieldInfo结构体描述参数 FunctionType body; // 函数方法体,body函数指针 std::string return_type; // 函数返回类型 EntryType set_body(); // 重要方法,设置工厂方法,并返回*(static_cast<EntryType*>(this)) EntryType *Find(const std::string &name); // 根据组件注册名字,找到注册项}// 注册get()方法模板特化,使用DMLC_REGISTRY_ENABLE(ParserFactoryReg<uint32_t>); 会构建Registry<ParserFactoryReg<uint32_t> >解析器工厂对象。// Registry是注册器类模板,所有注册器类都基于该模板,静态Get方法实现单例模式#define DMLC_REGISTRY_ENABLE(EntryType) \ template<> \ Registry<EntryType > *Registry<EntryType >::Get() { \ static Registry<EntryType > inst; \ return &inst; \ } \
Registry模板类定义
template<typename EntryType>class Registry { public: static Registry *Get(); // 静态Get()生成注册器对象,单例模式 private: std::vector<EntryType*> entry_list_; // 注册项列表 std::vector<const EntryType*> const_list_; // 注册项列表 std::map<std::string, EntryType*> fmap_; // 注册项map索引 EntryType &__REGISTER__(const std::string& name); // 注册方法,返回空白注册项}
DMLC注册宏定义
// 通过Get()得到工厂对象,调用__REGISTER__()获取空白注册项,需要被上层调用构建注册项#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \// DMLC数据解析器工厂方法注册,会调用上面DMLC_REGISTRY_REGISTER过程,得到注册项后进行设置工厂方法#define DMLC_REGISTER_DATA_PARSER(IndexType, TypeName, FactoryFunction) \DMLC_REGISTRY_REGISTER(::dmlc::ParserFactoryReg<IndexType>, \ ParserFactoryReg ## _ ## IndexType, TypeName) \.set_body(FactoryFunction)
libsvm解析器组件注册例子
// ./dmlc-core/src/data.cc中libsvm解析器组件注册:特征id为uint32_t类型,数据格式libsvm,工厂方法CreateLibSVMParser<uint32_t>。DMLC_REGISTER_DATA_PARSER(uint32_t, libsvm, data::CreateLibSVMParser<uint32_t>)// 将宏定义展开,等价于全局定义变量__make_ParserFactoryReg_uint32_t_libsvm__static __attribute__((unused)) ::dmlc::ParserFactoryReg<uint32_t> & __make_ParserFactoryReg_uint32_t_libsvm__ = \::dmlc::Registry<EntryType>::Get()->__REGISTER__("libsvm").set_body(data::CreateLibSVMParser<uint32_t>)
libsvm解析器工厂方法CreateLibSVMParser详解
template<typename IndexType>Parser<IndexType> *CreateLibSVMParser(const std::string& path, const std::map<std::string, std::string>& args, unsigned part_index, unsigned num_parts) { InputSplit* source = InputSplit::Create(path.c_str(), part_index, num_parts, "text"); // libsvm使用LibSVMParser解析 ParserImpl<IndexType> *parser = new LibSVMParser<IndexType>(source, 2); #if DMLC_ENABLE_STD_THREAD // 如果打开,则使用ThreadedParser进行包装,具备多线程加载功能 parser = new ThreadedParser<IndexType>(parser);#endif return parser;}
CreateLibSVMParser()
生成LibSVMParser
前会构建数据分片InputSplit
, 又是基于工厂模式,基类InputSplit
调用静态方法Create()
实例化,由于使用”text”数据类型,调用LineSplitter();
生成分片处理实例对象。
LineSplitter
继承于InputSplitBase
,初始化会调用InputSplitBase
中的Init
与ResetPartition
方法,主要完成对数据分片的功能:
1. Init
获取所有文件以及大小,生成对应数组变量file_offset_
,file_offset_[i]
表示文件i开始数据的字节偏移量。
2. ResetPartition
根据当前主机的rank值与所有主机数,由于数据分片采用均分方式,获取当前主机分片下的数据起始位置offset_begin_、offset_end_
,通过std::upper_bound
二分查找file_offset_定位跨度的文件列表索引访问file_ptr_、file_ptr_end_
,期间会有数据align对齐逻辑,并通过BeforeFirst
初始化文件读取状态和初始化内部变量,把文件指针挪到开始位置。
1.3 DMatrix对象构建过程
生成parser之后需要调用解析器不断获取数据生成DMatrix内存数据对象,对应主流程步骤5,具体过程如下:
1. 初始化空数据源SimpleCSRSource
对象source
,继承于类DataSource
,再上层类为dmlc::DataIter<RowBatch>
。主要初始化成员row_ptr_
指针。SimpleCSRSource
成员如下:
class SimpleCSRSource : public DataSource { public: // MetaInfo info; // 继承于DataSource,单机全局元信息 std::vector<size_t> row_ptr_; // 单机全局下,CSR行偏移 std::vector<RowBatch::Entry> row_data_; // 单机全局下,CSR稀疏存储内存块,以bacth vector方式存储 RowBatch batch_; // 调用Next会将数据指针绑定到该变量下 bool at_first_; // 开始标记,BeforeFirst会设置true, 首次调用Next设置成false。再调用Next会退出,也就是说只有一次Next调用执行,所以说Simple bool Next(); // 直接将全量数据的指针赋值给RowBatch batch_; const RowBatch &Value(); // 返回batch_变量引用}
了解整个数据加载过程,以下3个类非常重要:
struct SparseBatch { // 稀疏项描述 struct Entry { bst_uint index; // 稀疏项索引 bst_float fvalue; // 稀疏项数值 Entry() {} Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {} // 稀疏项比较,特征值排序会用到 inline static bool CmpValue(const Entry& a, const Entry& b) { return a.fvalue < b.fvalue; } }; // batch中稀疏向量,如果把一个instance看做一行,特征项为稀疏项,一行是一个稀疏向量 struct Inst { const Entry *data; // 稀疏向量数据块指针 bst_uint length; // 稀疏项个数 Inst() : data(0), length(0) {} Inst(const Entry *data, bst_uint length) : data(data), length(length) {} inline const Entry& operator[](size_t i) const { // 访问第i个稀疏项 return data[i]; } }; size_t size; // batch中的稀疏向量数};// 基于CSR数据格式,稀疏存储。类似地,ColBatch是基于CSC数据格式。struct RowBatch : public SparseBatch { size_t base_rowid; // 每个批次的rowid偏移 const size_t *ind_ptr; // 偏移数组,大小为size+1,ind_ptr[i]表示第i行开始数据的偏移位置 const Entry *data_ptr; // 批次数据块的指针 // batch中第i行的数据 inline Inst operator[](size_t i) const { return Inst(data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i + 1] - ind_ptr[i])); }};
2. source->CopyFrom(parser);
调用解析器从文件中不断批次解析出数据,存入变量row_data_
中,核心逻辑如下:
while (parser->Next()) { // 不断获取batch数据,并判断是否结束 const dmlc::RowBlock<uint32_t>& batch = parser->Value(); // 得到批次数据 if (batch.label != nullptr) { // 解析得到label值 info.labels.insert(info.labels.end(), batch.label, batch.label + batch.size); } if (batch.weight != nullptr) { // 解析得到weight值 info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size); } for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) { uint32_t index = batch.index[i]; bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i]; row_data_.push_back(SparseBatch::Entry(index, fvalue)); // 将值存入row_data_变量 this->info.num_col = std::max(this->info.num_col, static_cast<uint64_t>(index + 1)); }}
解析器对象parser
基于ThreadedParser
,包装了LibSVMParser
,两者都基于ParserImpl
, 成员如下:
template <typename IndexType>class ThreadedParser : public ParserImpl<IndexType> { Parser<IndexType> *base_; // 绑定LibSVMParser,真正数据解析器 ThreadedIter<std::vector<RowBlockContainer<IndexType> > > iter_; // 后端线程迭代器 std::vector<RowBlockContainer<IndexType> > *tmp_; //当前数据chunk};
ThreadedIter
基于DataIter
,需要实现BeforeFirst()
,Next()
, Value()
方法, 使用线程构建消费者与生产者模式,生产者预取数据。生产者内部调用ParseNext()
解析数据,核心调用FillData()
:
1)InputSplit
对象调用NextChunk
,每次最多读取16MB数据到Chunk
,再转化到Blob
对象中,内部使用ReadChunk
将文件数据读入到内存Chunk
中,使用overflow_
缓存分片位置到记录结束位置以便调整下一次文件指针,因此,batch最后读出的记录是完整的。
// 内存blob,只保留数据指针和大小,真正的数据hold在Chunkstruct Blob { void *dptr; // 内存blob指针 size_t size; // 内存blob大小};// 内存Chunk,保留着真正的数据struct Chunk { char *begin; char *end; std::vector<size_t> data; // 每次数据最多读入16MB explicit Chunk(size_t buffer_size) : begin(NULL), end(NULL), data(buffer_size + 1) {} bool Load(InputSplitBase *split, size_t buffer_size);};
2)使用OMP多线程并行调用ParseBlock
解析数据到具备vector
变量data
,基于数据均分定位线程处理的数据起始位置,为了保证边际在记录中的情况,需要BackFindEndLine
做指针相应偏移。ParseBlock
根据行偏移开始与行结束位置,将数据解析到线程对应RowBlockContainer<IndexType>
对象(*data)[tid]
中,ParsePair
底层函数会解析出weight、label、特征value,根据libsvm的冒号:
进行token分割。 RowBlockContainer
模板类定义如下,数据存储格式为CSR格式。
template<typename IndexType>struct RowBlockContainer { std::vector<size_t> offset; // offset[i]为第i行非0特征数据的偏移 std::vector<real_t> label; // offset[i]为第i行的label值 std::vector<real_t> weight; // offset[i]为第i行的weight值 std::vector<IndexType> field; // 特征field值 std::vector<IndexType> index; // 特征id值 std::vector<real_t> value; // 特征数值,需结合offset IndexType max_field; // 最大field值 IndexType max_index; // 最大特征id值}// RowBlock成员指针指向对应RowBlockContainer成员vector内存起始位置template<typename IndexType>struct RowBlock { size_t size; // 批次大小 const size_t *offset; // offset数据块指针 const real_t *label; // label数据块指针 const real_t *weight; // weight数据块指针 const IndexType *field; // field数据块指针 const IndexType *index; // index数据块指针 const real_t *value; // value数据块指针};
ThreadedParser
对象调用Next()
生成RowBlock<IndexType>
数据,即通过获取RowBlockContainer<IndexType>
数据的内存初始位置。如果消费完预取数据,则iter_.Next(&tmp_)
继续解析数据到tmp_中。每次parser->Value()
生成batch数据,更新MetaInfo
信息labels、weights、num_nonzeronum_row
。此外,稀疏特征数据项SparseBatch::Entry
插入row_data_
,更新对应row_ptr_
偏移。可以说row_data_、row_ptr_
是最终的全局数据。
3. 使用解析器中的数据源对象source构造返回SimpleDMatrix
对象,基类DMatrix
。SparsePageDMatrix
对象需要结合cache_prefix
文件,目前配置并不支持,以后再讲。至此,SimpleDMatrix
提供RowIterator
按照行遍历实例数据的能力,即CSR数据存储,SimpleDMatrix
调用Next()
直接绑定全量数据,这也是为什么叫做Simple的原因了。后续会将CSR转化为CSC格式,提供ColBatchIter
遍历器,主要利用ParallelGroupBuilder
对象多线程并行实现,具体细节详见此处LazyInitDMatrix过程。
- XGBoost解析系列-数据加载
- XGBoost解析系列-准备
- XGBoost解析系列-原理
- XGBoost解析系列--源码主流程
- XGBoost原理解析
- 数据加载的妙招解析
- feature_names mismatch XGBoost错误解析
- 学习笔记:XGBoost原理解析
- feature_names mismatch XGBoost错误解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- Android系列---JSON数据解析
- VMWare Linux虚拟机设置固定IP上网方法(靠谱)
- kvm qemu vhost-user
- 为VMware虚拟机中的Linux系统设置固定IP的方法
- iptables& firewalld
- java web mybatis 查询慢
- XGBoost解析系列-数据加载
- LRTimelapse Pro 4.8.3 Windows / Mac 简体中文 延时摄影处理软件
- 小而巧的bootstrap-wysiwyg 可以将任何一个div变成富文本编辑器
- 线程同步1 ------ 互斥锁
- 算符优先分析法
- 刷机软件-Duilib界面
- 机器学习---决策树算法
- JS用法之Object妙用一
- 音速启动-软件工具箱制作