修改hadoop的DBInputFormat类,实现一次读取多个数据库

来源:互联网 发布:北京云计算招聘 编辑:程序博客网 时间:2024/05/21 12:43
原理:hadoop提供了DBInputFormat类读取关系型数据库的,它是把一张表分成若干个inputSplit作为map输入,需要定义一个类实现Writable, DBWritable接口,作为map方法的输入值参数,一条一条记录的读取表的信息。根据以上原理,如果要一次读取多个数据库的表,就要把每个库作为一个inputSplit,作为map的输入。做这里之前,必须清楚这个job执行的流程,下图是我画的草图:


所以,要想读取多台数据库,需要修改的地方:

1) DBConfiguration类里的configureDB()方法,这个是初始化数据库连接的一些参数,是全局性的,我要连接的是六个库,所以全局的变量不能要。
2)DBInputFormat类里的setInput方法也需要修改
3)DBInputFormat里的getsplit()方法,这个是最重要的地方,这是分块的方法,没修改之前的是用表里的所有数据行数来分块的,我要的是六个库的连接。

其他的方法也要改,只是相比起来和上面的三个方法就要简单些,上面的三个方法都是初始化的时候调用的,必须想好怎么改。

不罗嗦了,上代码:

  public static void configureDB(JobConf job, String driverClass, String dbUrl
      , String userName, String passwd) {

//    job.set(DRIVER_CLASS_PROPERTY, driverClass);
//    job.set(URL_PROPERTY, dbUrl);
//    if(userName != null)
//      job.set(USERNAME_PROPERTY, userName);
//    if(passwd != null)
//      job.set(PASSWORD_PROPERTY, passwd);    
  }


    public static void setInput(JobConf job,
            Class<? extends DBWritable> inputClass, List<Object> conditions) {
        job.setInputFormat(DBInputFormat.class);

        DBConfiguration dbConf = new DBConfiguration(job);
        dbConf.setInputClass(inputClass);

        DBInputFormat.DBNodes = conditions.size();
        objs = conditions;
    }

    public static int DBNodes = 0;
    public static List<Object> objs = null;

public InputSplit[] getSplits(JobConf job, int chunks) throws IOException {

        InputSplit[] splits = new InputSplit[DBNodes];
        long tmpstart = 0l;
        long tmpend = 0l;

        for (int i = 0; i < DBNodes; i++) {
            DBInputSplit split;

            String tmpurl = "";
            String tmpsql = "";
            String tmppasswd = "";
            String tmpdriverclass = "";
            String tmpusername = "";
//               job.set(DRIVER_CLASS_PROPERTY, driverClass);
//                job.set(URL_PROPERTY, dbUrl);
//                if(userName != null)
//                  job.set(USERNAME_PROPERTY, userName);
//                if(passwd != null)
//                  job.set(PASSWORD_PROPERTY, passwd);            
            Object dto = objs.get(i);
            
            Field[] fields = dto.getClass().getDeclaredFields();
            for (Field field : fields) {

                field.setAccessible(true);
                
                try {

                    if (field.get(dto) != null) {

                        if (field.getName().equalsIgnoreCase("url")) {
                            
                            tmpurl = String.valueOf(field.get(dto));
                            job.set(DBConfiguration.URL_PROPERTY, tmpurl);
                        }
                        if (field.getName().equalsIgnoreCase("sql")) {

                            tmpsql=String.valueOf(field.get(dto));
                        }
                        if (field.getName().equalsIgnoreCase("passwd")) {

                            tmppasswd=String.valueOf(field.get(dto));
                            job.set(DBConfiguration.PASSWORD_PROPERTY, tmppasswd);
                        }
                        if (field.getName().equalsIgnoreCase("driverclass")) {

                            tmpdriverclass=String.valueOf(field.get(dto));
                            job.set(DBConfiguration.DRIVER_CLASS_PROPERTY, tmpdriverclass);
                        }
                        if (field.getName().equalsIgnoreCase("username")) {
                            
                            tmpusername=String.valueOf(field.get(dto));
                            job.set(DBConfiguration.USERNAME_PROPERTY, tmpusername);
                        }
                    }
                    else {
                        field.setAccessible(false);
                        continue;
                    }
                } catch (IllegalArgumentException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
                
                field.setAccessible(false);
                
            }
            try {
                Class.forName(job.get(DBConfiguration.DRIVER_CLASS_PROPERTY));
            } catch (ClassNotFoundException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            Connection connection = null;
            Statement statement = null;
            ResultSet rs = null;
            try {
                connection = DriverManager.getConnection(job.get(DBConfiguration.URL_PROPERTY),
                          job.get(DBConfiguration.USERNAME_PROPERTY),
                          job.get(DBConfiguration.PASSWORD_PROPERTY));
                statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY,
                        ResultSet.CONCUR_READ_ONLY);
                
                String subSql = "select count(*) as count " +  tmpsql.substring(tmpsql.indexOf("from"));
                // statement.setFetchSize(Integer.MIN_VALUE);
                
                rs = statement.executeQuery(subSql);
                rs.next();
                if(rs.wasNull()){
                    
                    continue;
                }
                
                tmpend += Long.valueOf(rs.getString("count"));
                statement.close();
                rs.close();
                connection.close();
                
            } catch (SQLException e) {
                e.printStackTrace();
            }
            
            split = new DBInputSplit(tmpstart, tmpend);
            split.setUrl(tmpurl);
            split.setSql(tmpsql);
            split.setPasswd(tmppasswd);
            split.setDriverClass(tmpdriverclass);
            split.setUserName(tmpusername);
            splits[i] = split;
        }
        return splits;
    }


    public DBInputSplit(long start, long end) {
            this.start = start;
            this.end = end;
        }

        /** {@inheritDoc} */
        public String[] getLocations() throws IOException {
            // TODO Add a layer to enable SQL "sharding" and support locality
            return new String[] {};
        }

        public String getUrl() {
            return url;
        }

        public void setUrl(String url) {
            this.url = url;
        }

        public String getSql() {
            return sql;
        }

        public void setSql(String sql) {
            this.sql = sql;
        }

        public String getUserName() {
            return userName;
        }

        public void setUserName(String userName) {
            this.userName = userName;
        }

        public String getPasswd() {
            return passwd;
        }

        public void setPasswd(String passwd) {
            this.passwd = passwd;
        }

        public String getDriverClass() {
            return driverClass;
        }

        public void setDriverClass(String driverClass) {
            this.driverClass = driverClass;
        }

        /**
         * @return The index of the first row to select
         */
        public long getStart() {
            return start;
        }

        /**
         * @return The index of the last row to select
         */
        public long getEnd() {
            return end;
        }

        /**
         * @return The total row count in this split
         */
        public long getLength() throws IOException {
            return end - start;
        }

        /** {@inheritDoc} */
        public void readFields(DataInput input) throws IOException {
            this.url=Text.readString(input);
            this.sql=Text.readString(input);
            this.userName=Text.readString(input);
            this.passwd=Text.readString(input);
            this.driverClass=Text.readString(input);
            start = input.readLong();
            end = input.readLong();
        }

        /** {@inheritDoc} */
        public void write(DataOutput output) throws IOException {
            Text.writeString(output, this.url);
            Text.writeString(output, this.sql);
            Text.writeString(output, this.userName);
            Text.writeString(output, this.passwd);
            Text.writeString(output, this.driverClass);
            output.writeLong(start);
            output.writeLong(end);
        }
    }

protected DBRecordReader(DBInputSplit split, Class<T> inputClass,
                JobConf job) throws SQLException {
            this.inputClass = inputClass;
            this.split = split;
            this.job = job;

            try {
                Class.forName(split.getDriverClass());
            } catch (ClassNotFoundException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            myConnection = DriverManager.getConnection(split.getUrl(),
                    split.getUserName(), split.getPasswd());

            statement = myConnection.createStatement(ResultSet.TYPE_FORWARD_ONLY,
                    ResultSet.CONCUR_READ_ONLY);

            // statement.setFetchSize(Integer.MIN_VALUE);
            results = statement.executeQuery(split.getSql());
        }



0 0
原创粉丝点击