CWRU (Case Western Reserve University) Bearing Fault Data: Download and Preparation

Source: Internet | Editor: 程序博客网 | Date: 2024/04/25 09:42




Introduction to the CWRU bearing dataset

CWRU bearing data official website

Python 2 cwru library: GitHub repository

GitHub data download link

Under Python 2 you can install and use the CWRU library directly. It downloads the data and slices it into training and test sets ready for model training and evaluation.

Installing under Python 2:

Via pip:

$ pip install --user cwru

Or from the GitHub source code:

$ python setup.py install

Usage

import cwru
data = cwru.CWRU("12DriveEndFault", "1797", 384)

The attributes data.X_train, data.y_train, data.X_test, data.y_test, data.labels, and data.nclasses can then be used to train and evaluate a model.
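To illustrate how these attributes might feed a model, the sketch below trains a nearest-centroid classifier in pure NumPy. Since the real dataset requires a download, synthetic arrays with the same shapes stand in for data.X_train and friends; the array names mirror the library, everything else is hypothetical:

```python
import numpy as np

# Synthetic stand-ins shaped like the cwru attributes: X is (n, length), y is (n,).
rng = np.random.default_rng(0)
length, nclasses = 384, 3
y_train = np.repeat(np.arange(nclasses), 20)
X_train = rng.normal(size=(60, length)) + y_train[:, None]  # class-dependent offset
y_test = y_train.copy()
X_test = X_train + rng.normal(scale=0.1, size=X_train.shape)

# Nearest-centroid classifier: one mean "template" signal per class,
# then assign each test clip to the closest template.
centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in range(nclasses)])
dists = ((X_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
pred = dists.argmin(axis=1)
print("accuracy:", (pred == y_test).mean())
```

With well-separated synthetic classes the accuracy is essentially perfect; on the real vibration data you would substitute a proper classifier.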

Parameters of CWRU:

exp: one of '12DriveEndFault', '12FanEndFault', '48DriveEndFault'

rpm: one of '1797', '1772', '1750', '1730'

length: the length of each signal clip
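The constructor rejects anything outside these values, so a quick standalone check can catch bad arguments before triggering a download. This is a sketch that mirrors the library's validation, not part of its API:

```python
# Valid argument values for cwru.CWRU(exp, rpm, length), per the docs above.
VALID_EXPS = ('12DriveEndFault', '12FanEndFault', '48DriveEndFault')
VALID_RPMS = ('1797', '1772', '1750', '1730')

def validate_cwru_args(exp, rpm, length):
    """Return True if cwru.CWRU would accept these arguments (hypothetical helper)."""
    return exp in VALID_EXPS and rpm in VALID_RPMS and length > 0

print(validate_cwru_args('12DriveEndFault', '1797', 384))  # True
print(validate_cwru_args('48DriveEndFault', '1800', 384))  # False: rpm not offered
```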


Python 3 version:

Because of the differences between Python 3 and Python 2, the original Python 2 code was modified as follows:

import os
import errno
import random
import urllib.request as urllib
import numpy as np
from scipy.io import loadmat


class CWRU:
    def __init__(self, exp, rpm, length):
        if exp not in ('12DriveEndFault', '12FanEndFault', '48DriveEndFault'):
            print("wrong experiment name: {}".format(exp))
            exit(1)
        if rpm not in ('1797', '1772', '1750', '1730'):
            print("wrong rpm value: {}".format(rpm))
            exit(1)
        # root directory of all data
        rdir = os.path.join('Datasets/CWRU', exp, rpm)
        print(rdir)

        # metadata.txt maps experiment/rpm to file names and download links
        fmeta = os.path.join(os.path.dirname(__file__), 'metadata.txt')
        all_lines = open(fmeta).readlines()
        lines = []
        for line in all_lines:
            l = line.split()
            if (l[0] == exp or l[0] == 'NormalBaseline') and l[1] == rpm:
                lines.append(l)

        self.length = length  # sequence length
        self._load_and_slice_data(rdir, lines)
        # shuffle training and test arrays
        self._shuffle()
        self.labels = tuple(line[2] for line in lines)
        self.nclasses = len(self.labels)  # number of classes

    def _mkdir(self, path):
        try:
            os.makedirs(path)
        except OSError as exc:
            if exc.errno == errno.EEXIST and os.path.isdir(path):
                pass
            else:
                print("can't create directory '{}'".format(path))
                exit(1)

    def _download(self, fpath, link):
        print("Downloading to: '{}'".format(fpath))
        # urllib.URLopener is deprecated in Python 3; urlretrieve does the same job
        urllib.urlretrieve(link, fpath)

    def _load_and_slice_data(self, rdir, infos):
        self.X_train = np.zeros((0, self.length))
        self.X_test = np.zeros((0, self.length))
        self.y_train = []
        self.y_test = []
        for idx, info in enumerate(infos):
            # directory of this file
            fdir = os.path.join(rdir, info[0], info[1])
            self._mkdir(fdir)
            fpath = os.path.join(fdir, info[2] + '.mat')
            if not os.path.exists(fpath):
                self._download(fpath, info[3].rstrip('\n'))
            mat_dict = loadmat(fpath)
            # In Python 3, filter() returns an iterator rather than a list,
            # so materialize it before indexing
            key = [k for k in mat_dict.keys() if 'DE_time' in k][0]
            time_series = mat_dict[key][:, 0]
            # drop the tail that doesn't fill a whole clip, then slice into rows
            idx_last = -(time_series.shape[0] % self.length)
            clips = time_series[:idx_last].reshape(-1, self.length)
            n = clips.shape[0]
            n_split = int(3 * n / 4)
            self.X_train = np.vstack((self.X_train, clips[:n_split]))
            self.X_test = np.vstack((self.X_test, clips[n_split:]))
            self.y_train += [idx] * n_split
            self.y_test += [idx] * (n - n_split)

    def _shuffle(self):
        # shuffle training samples
        index = list(range(self.X_train.shape[0]))
        random.Random(0).shuffle(index)
        self.X_train = self.X_train[index]
        self.y_train = tuple(self.y_train[i] for i in index)
        # shuffle test samples
        index = list(range(self.X_test.shape[0]))
        random.Random(0).shuffle(index)
        self.X_test = self.X_test[index]
        self.y_test = tuple(self.y_test[i] for i in index)
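The heart of _load_and_slice_data is the trim-and-reshape step: the tail of the signal that does not fill a whole clip is dropped, the remainder is reshaped into (n, length) rows, and the first 3/4 of the rows become training data. A standalone sketch with a synthetic signal in place of the downloaded DE_time channel:

```python
import numpy as np

length = 384
time_series = np.arange(10_000, dtype=float)  # stand-in for a DE_time channel

# Drop the tail that doesn't fill a whole clip, then reshape into rows.
# (Note: if the signal length were an exact multiple of `length`, idx_last
# would be 0 and time_series[:0] would be empty -- a latent edge case in
# the code above.)
idx_last = -(time_series.shape[0] % length)
clips = time_series[:idx_last].reshape(-1, length)

# First 3/4 of the clips become training data, the rest test data.
n_split = int(3 * clips.shape[0] / 4)
train, test = clips[:n_split], clips[n_split:]
print(clips.shape, train.shape[0], test.shape[0])  # (26, 384) 19 7
```

Here 10,000 samples yield 26 full clips of 384 samples (16 samples are discarded), split 19/7 between training and test.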



