`

MySQL PreparedStatement 的源码实现分析

阅读更多

 

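今天分析了一下 MySQL Connector/J 5.1 版本驱动包的 PreparedStatement 实现源码,发现在默认配置下(useServerPrepStmts=false),驱动并没有使用真正的服务器端预编译,而是跟普通的 Statement 一样:在客户端把参数值拼装进 SQL,得到完整的 SQL 文本,再通过 socket 以 MySQL 客户端/服务器协议中的 COM_QUERY 数据包发送给服务器;服务器返回结果后,驱动再把结果封装成 ResultSet,以方便后续的数据迭代处理.(只有开启 useServerPrepStmts=true 时,驱动才会改用 ServerPreparedStatement 进行服务器端预编译.)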

 

 

	/**
	 * Executes this prepared statement as a query and returns its result set.
	 *
	 * <p>The complete request (static SQL fragments with the rendered parameter
	 * values spliced in) is built by {@code fillSendPacket()} and handed to
	 * {@code executeInternal()}. If the connection's current catalog differs from
	 * the one this statement was created against, it is switched for the duration
	 * of the query and switched back afterwards.
	 *
	 * @return the result set produced by the query (also stored in
	 *         {@code this.results})
	 * @throws SQLException if {@code checkClosed()} or {@code checkForDml()}
	 *         rejects the call, or if executing the query fails
	 */
	public java.sql.ResultSet executeQuery() throws SQLException {
		checkClosed();

		ConnectionImpl locallyScopedConn = this.connection;

		checkForDml(this.originalSql, this.firstCharOfStmt);

		CachedResultSetMetaData cachedMetadata = null;

		// Everything below must run as one unit, so synchronize on the
		// connection's mutex (queries issued through the connection itself
		// synchronize on the same mutex).
		synchronized (locallyScopedConn.getMutex()) {
			clearWarnings();

			// Computed once and reused by every executeInternal() call below.
			boolean doStreaming = createStreamingResultSet();

			this.batchedGeneratedKeys = null;

			// Raise net_write_timeout before streaming a result set: callers that
			// keep a streaming result open for a long time otherwise tend to hit
			// the server's write timeout, and one extra round-trip is cheap by
			// comparison. This is reset by RowDataDynamic.close().
			if (doStreaming
					&& this.connection.getNetTimeoutForStreamingResults() > 0) {
				locallyScopedConn.execSQL(this, "SET net_write_timeout="
						+ this.connection.getNetTimeoutForStreamingResults(),
						-1, null, ResultSet.TYPE_FORWARD_ONLY,
						ResultSet.CONCUR_READ_ONLY, false, this.currentCatalog,
						null, false);
			}

			// Serialize the request: fillSendPacket() writes the static SQL
			// fragments with the rendered parameter values spliced in between
			// them, i.e. the complete SQL text of the query.
			Buffer sendPacket = fillSendPacket();

			// Release the previous result set unless the connection or this
			// statement is configured to keep results open.
			if (this.results != null) {
				if (!this.connection.getHoldResultsOpenOverStatementClose()) {
					if (!this.holdResultsOpenOverClose) {
						this.results.realClose(false);
					}
				}
			}

			// Switch to this statement's catalog if the connection is on another.
			String oldCatalog = null;

			if (!locallyScopedConn.getCatalog().equals(this.currentCatalog)) {
				oldCatalog = locallyScopedConn.getCatalog();
				locallyScopedConn.setCatalog(this.currentCatalog);
			}

			// Look up cached result-set metadata for this SQL text, if enabled.
			if (locallyScopedConn.getCacheResultSetMetadata()) {
				cachedMetadata = locallyScopedConn.getCachedMetaData(this.originalSql);
			}

			Field[] metadataFromCache = null;

			if (cachedMetadata != null) {
				metadataFromCache = cachedMetadata.fields;
			}

			if (locallyScopedConn.useMaxRows()) {
				// A max-rows limit is in use on this connection. If the SQL has
				// its own LIMIT clause, pass maxRows to executeInternal();
				// otherwise set the session's SQL_SELECT_LIMIT first.
				if (this.hasLimitClause) {
					this.results = executeInternal(this.maxRows, sendPacket,
							doStreaming, true,
							metadataFromCache, false);
				} else {
					if (this.maxRows <= 0) {
						executeSimpleNonQuery(locallyScopedConn,
								"SET OPTION SQL_SELECT_LIMIT=DEFAULT");
					} else {
						executeSimpleNonQuery(locallyScopedConn,
								"SET OPTION SQL_SELECT_LIMIT=" + this.maxRows);
					}

					this.results = executeInternal(-1, sendPacket,
							doStreaming, true,
							metadataFromCache, false);
				}
			} else {
				this.results = executeInternal(-1, sendPacket,
						doStreaming, true,
						metadataFromCache, false);
			}

			// Restore the previous catalog exactly once, for every branch above.
			// NOTE(review): if executeInternal() throws, the catalog is left
			// switched; consider moving this restore into a finally block.
			if (oldCatalog != null) {
				locallyScopedConn.setCatalog(oldCatalog);
			}

			// On a cache hit, pass the cached metadata along with the results;
			// with caching enabled but no entry yet, pass null so one is created.
			if (cachedMetadata != null) {
				locallyScopedConn.initializeResultsMetadataFromCache(this.originalSql,
						cachedMetadata, this.results);
			} else {
				if (locallyScopedConn.getCacheResultSetMetadata()) {
					locallyScopedConn.initializeResultsMetadataFromCache(this.originalSql,
							null /* will be created */, this.results);
				}
			}
		}

		this.lastInsertId = this.results.getUpdateID();

		return this.results;
	}

 

 

 

下面我们再来一起分析看看,fillSendPacket 是如何对请求数据进行包装的.

 

 

	/**
	 * Builds the outgoing QUERY packet from this statement's current parameters.
	 *
	 * <p>This simply forwards the statement's per-parameter state to the
	 * four-argument overload:
	 * <ul>
	 * <li>{@code parameterValues} - one byte array per {@code ?} placeholder,
	 * holding the already-rendered value (for example the quoted, escaped
	 * string produced by {@code setString}). The SQL text itself is NOT kept
	 * here; it lives in {@code staticSqlStrings}, split at each {@code ?}.</li>
	 * <li>{@code parameterStreams} - the InputStream for each parameter that was
	 * supplied as a stream.</li>
	 * <li>{@code isStream} - per-parameter flag selecting the stream instead of
	 * the byte value.</li>
	 * <li>{@code streamLengths} - per-parameter stream length, used when
	 * useStreamLengthsInPrepStmts is enabled.</li>
	 * </ul>
	 *
	 * <p>Example: for {@code "select * from userjf where aid = ?"} followed by
	 * {@code ps.setString(1, "param1")}, {@code staticSqlStrings} is
	 * {@code {"select * from userjf where aid = ", ""}} and
	 * {@code parameterValues[0]} holds the bytes of {@code 'param1'}.
	 *
	 * @return the filled send packet
	 * @throws SQLException if the packet cannot be built
	 */
	protected Buffer fillSendPacket() throws SQLException {
		final byte[][] values = this.parameterValues;
		final int[] lengths = this.streamLengths;

		return fillSendPacket(values, this.parameterStreams, this.isStream,
				lengths);
	}

 

 

 

	/**
	 * Serializes a complete QUERY request into the IO layer's shared send packet.
	 *
	 * <p>The packet contains the QUERY command byte, then an optional statement
	 * comment, then the SQL text rebuilt by interleaving {@code staticSqlStrings}
	 * (the SQL split at each {@code ?}) with the rendered parameter values.
	 * Stream-valued parameters are copied in via {@code streamToBytes}.
	 *
	 * @param batchedParameterStrings
	 *            rendered bytes of each parameter (one entry per placeholder)
	 * @param batchedParameterStreams
	 *            stream of each parameter whose {@code batchedIsStream} flag is set
	 * @param batchedIsStream
	 *            per-parameter flag: read the value from the stream array
	 * @param batchedStreamLengths
	 *            per-parameter stream length, used for pre-sizing and passed to
	 *            {@code streamToBytes} together with useStreamLengthsInPrepStmts
	 * @return the shared send packet, filled with the request
	 * @throws SQLException if a parameter has neither a byte value nor a stream
	 *             (SQL_STATE_WRONG_NO_OF_PARAMETERS), or if a helper such as
	 *             {@code streamToBytes} fails
	 */
	protected Buffer fillSendPacket(byte[][] batchedParameterStrings,
			InputStream[] batchedParameterStreams, boolean[] batchedIsStream,
			int[] batchedStreamLengths) throws SQLException {
		/*
		 Obtain the send buffer from the IO layer (the name suggests it is shared
		 and reused, which is why it is cleared before writing).
		*/
		Buffer sendPacket = this.connection.getIO().getSharedSendPacket();
		/*
		Reset the buffer before writing this request.
		*/
		sendPacket.clear();
		/*
		First byte is the command: MysqlDefs.QUERY (COM_QUERY, value 3 in the
		MySQL client/server protocol). Everything after it is the SQL text.
		*/
		sendPacket.writeByte((byte) MysqlDefs.QUERY);

		boolean useStreamLengths = this.connection
				.getUseStreamLengthsInPrepStmts();

		//
		// Try and get this allocation as close as possible
		// for BLOBs. Only the statement comment and stream parameters (when
		// their lengths are used) are counted; the SQL text and the byte
		// values are not part of this pre-size estimate.
		//
		int ensurePacketSize = 0;

		String statementComment = this.connection.getStatementComment();
		
		byte[] commentAsBytes = null;
		
		if (statementComment != null) {
			if (this.charConverter != null) {
				commentAsBytes = this.charConverter.toBytes(statementComment);
			} else {
				commentAsBytes = StringUtils.getBytes(statementComment, this.charConverter,
						this.charEncoding, this.connection
								.getServerCharacterEncoding(), this.connection
								.parserKnowsUnicode());
			}
			
			ensurePacketSize += commentAsBytes.length;
			ensurePacketSize += 6; // for /*[space] [space]*/
		}
	
		for (int i = 0; i < batchedParameterStrings.length; i++) {
			if (batchedIsStream[i] && useStreamLengths) {
				ensurePacketSize += batchedStreamLengths[i];
			}
		}

		if (ensurePacketSize != 0) {
			sendPacket.ensureCapacity(ensurePacketSize);
		}

		// Prefix the SQL with the configured statement comment: "/* comment */ ".
		if (commentAsBytes != null) {
			sendPacket.writeBytesNoNull(Constants.SLASH_STAR_SPACE_AS_BYTES);
			sendPacket.writeBytesNoNull(commentAsBytes);
			sendPacket.writeBytesNoNull(Constants.SPACE_STAR_SLASH_SPACE_AS_BYTES);
		}
		
		/*
		 Rebuild the full SQL text. For N placeholders, staticSqlStrings holds N+1
		 fragments (the SQL split at each '?'); fragment i is written, then value i,
		 and the last fragment closes the statement.

		 Example: "select * from user where id = ? and name = ?" with values 123
		 and "jack":
			staticSqlStrings:        {"select * from user where id = ", " and name = ", ""}
			batchedParameterStrings: {"123", "'jack'"}  (bytes, already quoted/escaped
			                                            by the setXxx methods)
		 produces: select * from user where id = 123 and name = 'jack'
		 which is appended to sendPacket.
		*/
		for (int i = 0; i < batchedParameterStrings.length; i++) {
			// A parameter with neither a byte value nor a stream has no value.
			if ((batchedParameterStrings[i] == null)
					&& (batchedParameterStreams[i] == null)) {
				throw SQLError.createSQLException(Messages
						.getString("PreparedStatement.40") //$NON-NLS-1$
						+ (i + 1), SQLError.SQL_STATE_WRONG_NO_OF_PARAMETERS);
			}

			sendPacket.writeBytesNoNull(this.staticSqlStrings[i]);

			if (batchedIsStream[i]) {
				streamToBytes(sendPacket, batchedParameterStreams[i], true,
						batchedStreamLengths[i], useStreamLengths);
			} else {
				sendPacket.writeBytesNoNull(batchedParameterStrings[i]);
			}
		}

		// Trailing SQL fragment after the last placeholder.
		sendPacket.writeBytesNoNull(this.staticSqlStrings[batchedParameterStrings.length]);
		

		return sendPacket;
	}

 

 

因此从这里我们可以发现,在默认的客户端预处理模式下,MySQL 驱动的 PreparedStatement 实际上只是把 SQL 与对应的参数值组装成了完整的 SQL,然后再发起数据请求.

普通的 Statement 存在着 SQL 注入的问题.如:

登录模块
如:
User validUser = login.getUserInfo(user.getName());这里的 user.getName() 是前台从 TextField 控件中获得的值,没有做任何处理,于是再看看 getUserInfo 方法,如下:

 // Deliberately UNSAFE example used to illustrate SQL injection: userName comes
 // from a front-end text field without any processing and is concatenated
 // directly into the SQL text, so an input such as  cjcj' or '1'='1  rewrites
 // the WHERE clause. Safe alternative:
 //   "Select * from WEB_USER where NAME=?" + PreparedStatement.setString(1, userName)
 // NOTE: abbreviated excerpt - the "...." line, the catch/finally for the try,
 // and the return statement are omitted, so it does not compile as-is.
 public User getUserInfo(String userName)
    {
        User validUser = null;
        // Injection point: untrusted input concatenated into SQL.
        String sql = "Select * from WEB_USER where NAME='" + userName + "'";
        Database db = null;
        ResultSet rs = null;
        try {
            db = new Database("XXX");
            rs = db.execQuery(sql);   
            if (rs != null && rs.next()) {
                 validUser = new User();
                 ....
              }
            }
         }

我们看到,从前台传过来的 userName 没有经过任何处理,就被直接拼凑进了 SQL 语句,所以如果输入者精心构造输入,就可以突破第一道屏障,生成一个有效的用户对象,比如:

输入: cjcj' or '1'='1

这样的字符串输入到后台的SQL语句就为:
select * from web_user where name='cjcj' or '1'='1'
显然,这个rs肯定是非空的。我们完成了突破第一个屏障的任务。

 

而 PreparedStatement 对此做了改善:它在设置参数时,会对存在 SQL 注入风险的字符进行转义处理.如上述输入 cjcj' or '1'='1 ,会被处理成 cjcj\' or \'1\'=\'1 ,最后拼装成的 SQL 就变成了 select * from web_user where name='cjcj\' or \'1\'=\'1',整个输入只会被当作一个普通的字符串值来比较(若服务器启用了 NO_BACKSLASH_ESCAPES 模式,则会走另一条分支,见下方代码).下面我们抽取 PreparedStatement.setString(int parameterIndex, String x) 方法,一起来看看源码实现.

	/**
	 * Set a parameter to a Java String value. The driver converts this to a SQL
	 * VARCHAR or LONGVARCHAR value (depending on the arguments size relative to
	 * the driver's limits on VARCHARs) when it sends it to the database.
	 *
	 * <p>In this client-side implementation the value is rendered here, once,
	 * into the bytes that {@code fillSendPacket} later splices into the SQL text:
	 * <ul>
	 * <li>{@code null} is delegated to {@code setNull(parameterIndex, Types.CHAR)}.</li>
	 * <li>If {@code isNoBackslashEscapesSet()} is true, backslash escaping is not
	 * used: a value that needs no escaping is wrapped in single quotes as-is;
	 * otherwise its raw bytes are passed to {@code setBytes} (presumably
	 * rendered as a binary/hex literal there - confirm in setBytes).</li>
	 * <li>Otherwise, when escaping is needed (or for LOAD DATA queries), NUL,
	 * newline, carriage return, backslash, single quote, Ctrl-Z, and (in ANSI
	 * mode only) double quote are backslash-escaped and the value is wrapped in
	 * single quotes, so {@code cjcj' or '1'='1} becomes
	 * {@code 'cjcj\' or \'1\'=\'1'}. A value that needs no escaping is simply
	 * wrapped in quotes during byte conversion.</li>
	 * <li>For LOAD DATA queries the bytes use the platform default encoding;
	 * otherwise the connection's converter/encoding is used.</li>
	 * </ul>
	 * 
	 * @param parameterIndex
	 *            the first parameter is 1...
	 * @param x
	 *            the parameter value; {@code null} sets SQL NULL
	 * 
	 * @exception SQLException
	 *                if a database access error occurs
	 */
	public void setString(int parameterIndex, String x) throws SQLException {
		// if the passed string is null, then set this column to null
		if (x == null) {
			setNull(parameterIndex, Types.CHAR);
		} else {
			checkClosed();
			
			int stringLength = x.length();

			if (this.connection.isNoBackslashEscapesSet()) {
				// Backslash escapes are disabled on this connection.
				// Scan for any nasty chars that would need escaping.

				boolean needsHexEscape = isEscapeNeededForString(x,
						stringLength);

				if (!needsHexEscape) {
					// Nothing to escape: just wrap the value in single quotes.
					byte[] parameterAsBytes = null;

					StringBuffer quotedString = new StringBuffer(x.length() + 2);
					quotedString.append('\'');
					quotedString.append(x);
					quotedString.append('\'');
					
					if (!this.isLoadDataQuery) {
						parameterAsBytes = StringUtils.getBytes(quotedString.toString(),
								this.charConverter, this.charEncoding,
								this.connection.getServerCharacterEncoding(),
								this.connection.parserKnowsUnicode());
					} else {
						// Send with platform character encoding
						parameterAsBytes = quotedString.toString().getBytes();
					}
					
					setInternal(parameterIndex, parameterAsBytes);
				} else {
					// Escaping would be needed but backslashes can't be used:
					// hand the raw bytes to setBytes instead.
					byte[] parameterAsBytes = null;

					if (!this.isLoadDataQuery) {
						parameterAsBytes = StringUtils.getBytes(x,
								this.charConverter, this.charEncoding,
								this.connection.getServerCharacterEncoding(),
								this.connection.parserKnowsUnicode());
					} else {
						// Send with platform character encoding
						parameterAsBytes = x.getBytes();
					}
					
					setBytes(parameterIndex, parameterAsBytes);
				}

				// NOTE(review): unlike the end of this method, this path does not
				// set parameterTypes[...] = Types.VARCHAR - confirm that is intended.
				return;
			}

			String parameterAsString = x;
			boolean needsQuoted = true;
			
			if (this.isLoadDataQuery || isEscapeNeededForString(x, stringLength)) {
				needsQuoted = false; // saves an allocation later (buf adds the quotes itself)
				
				/*
				1 Create the buffer buf.
				2 Put an opening ' at the start of the buffer.
				3 Walk the parameter value x one character at a time; special
				  characters are escaped, e.g. ' becomes \' .
				4 Put a closing ' at the end of the buffer.

				  If x is  cjcj' or '1'='1  the buffer ends up as  'cjcj\' or \'1\'=\'1'
				*/
				StringBuffer buf = new StringBuffer((int) (x.length() * 1.1));
				buf.append('\'');
	
				for (int i = 0; i < stringLength; ++i) {
					char c = x.charAt(i);
	
					switch (c) {
					case 0: /* Must be escaped for 'mysql' */
						buf.append('\\');
						buf.append('0');
	
						break;
	
					case '\n': /* Must be escaped for logs */
						buf.append('\\');
						buf.append('n');
	
						break;
	
					case '\r':
						buf.append('\\');
						buf.append('r');
	
						break;
	
					case '\\':
						buf.append('\\');
						buf.append('\\');
	
						break;
	
					case '\'':
						buf.append('\\');
						buf.append('\'');
	
						break;
	
					case '"': /* Better safe than sorry; escaped only in ANSI mode */
						if (this.usingAnsiMode) {
							buf.append('\\');
						}
	
						buf.append('"');
	
						break;
	
					case '\032': /* This gives problems on Win32 */
						buf.append('\\');
						buf.append('Z');
	
						break;
	
					default:
						buf.append(c);
					}
				}
	
				buf.append('\'');
	
				parameterAsString = buf.toString();
			}

			byte[] parameterAsBytes = null;

			if (!this.isLoadDataQuery) {
				if (needsQuoted) {
					// No escaping was needed: add the surrounding quotes while
					// converting to bytes.
					parameterAsBytes = StringUtils.getBytesWrapped(parameterAsString,
						'\'', '\'', this.charConverter, this.charEncoding, this.connection
								.getServerCharacterEncoding(), this.connection
								.parserKnowsUnicode());
				} else {
					parameterAsBytes = StringUtils.getBytes(parameterAsString,
							this.charConverter, this.charEncoding, this.connection
									.getServerCharacterEncoding(), this.connection
									.parserKnowsUnicode());
				}
			} else {
				// Send with platform character encoding
				parameterAsBytes = parameterAsString.getBytes();
			}

			setInternal(parameterIndex, parameterAsBytes);
			
			this.parameterTypes[parameterIndex - 1 + getParameterIndexOffset()] = Types.VARCHAR;
		}
	}

 

 

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics