`
iffiffj
  • 浏览: 156510 次
  • 性别: Icon_minigender_1
  • 来自: 广州
社区版块
存档分类
最新评论

DBUnit的NoPrimaryKey的解决

阅读更多
在 DBUnit 中，如果表没有主键，执行需要主键的操作时可能会抛出 NoPrimaryKeyException，导致无法导入数据。这种情况在 Hibernate/JPA 多对多关联生成的中间表中尤为常见。
解决办法是继承 DefaultMetadataHandler 并重写其 getPrimaryKeys 方法：遇到没有主键的表时，就把该表的所有字段都当作主键。然后在提供 IDatabaseTester 时，通过 DatabaseConfig.PROPERTY_METADATA_HANDLER 属性把这个 Handler 设置到数据库连接的配置中。
package org.iata.ios.test.utils;

import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

import org.dbunit.IDatabaseTester;
import org.dbunit.JdbcDatabaseTester;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DefaultMetadataHandler;
import org.dbunit.database.IDatabaseConnection;
import org.h2.tools.SimpleResultSet;
import org.h2.tools.SimpleRowSource;

public class DBUnitHelper {

    private static IDatabaseTester tester;

    private DBUnitHelper() {

    }

     private static IDatabaseTester newDatabaseTester() throws Exception {
        JdbcDatabaseTester jdbcDatabaseTester = new JdbcDatabaseTester(
                Configuration.getValue("database.connection.driver_class"),
                Configuration.getValue("database.connection.url"),
                Configuration.getValue("database.connection.username"),
                Configuration.getValue("database.connection.password")) {
            public IDatabaseConnection getConnection() throws Exception {
                IDatabaseConnection connection = super.getConnection();
                connection.getConfig().setProperty(DatabaseConfig.PROPERTY_METADATA_HANDLER,
                        new MyDefaultMetadataHandler());
                return connection;
            }

        };
        return jdbcDatabaseTester;
    }

    private static IDatabaseTester getDatabaseTester() throws Exception {
        if (tester == null) {
            tester = newDatabaseTester();
        }
        return tester;
    }

    public static void executeBeforeOperations(List<DataSetOperation> list) throws Exception {
        final IDatabaseTester databaseTester = getDatabaseTester();
        if (databaseTester != null) {
            for (DataSetOperation dso : list) {
                databaseTester.setSetUpOperation(dso.getOperation());
                databaseTester.setDataSet(dso.getDataSet());
                databaseTester.onSetup();
            }
        }
    }

    public static void executeAfterOperations(List<DataSetOperation> list) throws Exception {
        final IDatabaseTester databaseTester = getDatabaseTester();
        if (databaseTester != null) {
            for (DataSetOperation dso : list) {
                databaseTester.setTearDownOperation(dso.getOperation());
                databaseTester.setDataSet(dso.getDataSet());
                databaseTester.onTearDown();
            }
        }
    }

    static class MyDefaultMetadataHandler extends DefaultMetadataHandler {
        public ResultSet getPrimaryKeys(DatabaseMetaData metaData, String schemaName, String tableName)
            throws SQLException {
            ResultSet resultSet = super.getPrimaryKeys(metaData, schemaName, tableName);
            if (resultSet.next()) {
                resultSet.close();
                resultSet = super.getPrimaryKeys(metaData, schemaName, tableName);
            } else {
                resultSet.close();
                ResultSet pkRS = super.getColumns(metaData, schemaName, tableName);
                List<Object[]> list = new ArrayList<Object[]>();
                SimpleResultSet simpleResultSet = new SimpleResultSet(new MySimpleRowSource(list));
                int i = 1;
                boolean isInit = false;
                try {
                    while (pkRS.next()) {
                        if (!isInit) {
                            ResultSetMetaData md = pkRS.getMetaData();
                            simpleResultSet.addColumn("TABLE_CAT", md.getColumnType(1), md.getPrecision(1),
                                    md.getScale(1));
                            simpleResultSet.addColumn("TABLE_SCHEM", md.getColumnType(2), md.getPrecision(2),
                                    md.getScale(2));
                            simpleResultSet.addColumn("TABLE_NAME", md.getColumnType(3), md.getPrecision(3),
                                    md.getScale(3));
                            simpleResultSet.addColumn("COLUMN_NAME", md.getColumnType(4), md.getPrecision(4),
                                    md.getScale(4));
                            simpleResultSet.addColumn("KEY_SEQ", md.getColumnType(5), md.getPrecision(5),
                                    md.getScale(5));
                            simpleResultSet.addColumn("COLUMN_NAME", md.getColumnType(4), md.getPrecision(4),
                                    md.getScale(4));
                            isInit = true;
                        }
                        Object[] objs = new Object[] { pkRS.getString(1), pkRS.getString(2), pkRS.getString(3),
                                pkRS.getString(4), i++, pkRS.getString(4) };
                        list.add(objs);
//                        String name = resultSet.getString(4);
//                        int sequence = resultSet.getInt(5);
//                        list.add(new PrimaryKeyData(name, sequence));
                    }
                } finally {
                    pkRS.close();
                }
                resultSet = simpleResultSet;
            }
            return resultSet;
        }
    }

    static class MySimpleRowSource implements SimpleRowSource {

        List<Object[]> datas = new ArrayList<Object[]>();
        int current;

        public MySimpleRowSource(List<Object[]> datas) {
            this.datas = datas;
            current = 0;
        }

        public Object[] readRow() throws SQLException {
            return datas.size() > current ? datas.get(current++) : null;
        }

        public void close() {
            datas.clear();
        }

        public void reset() throws SQLException {
            current = 0;
        }

    }
}

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics