This repository was archived by the owner on Feb 27, 2025. It is now read-only.

Commit 539b561 ("test1")
1 parent e328e20

File tree: 1 file changed (+30, -74 lines)


src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala

Lines changed: 30 additions & 74 deletions
@@ -22,7 +22,8 @@ import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
 import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}
 
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.Breaks.{breakable,break}
 
 /**
  * BulkCopyUtils Object implements common utility function used by both datapool and
@@ -179,47 +180,6 @@ object BulkCopyUtils extends Logging {
         conn.createStatement.executeQuery(queryStr)
     }
 
-    /**
-     * getComputedCols
-     * utility function to get computed columns.
-     * Use computed column names to exclude computed column when matching schema.
-     */
-    private[spark] def getComputedCols(
-            conn: Connection,
-            table: String): List[String] = {
-        val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
-        val computedColRs = conn.createStatement.executeQuery(queryStr)
-        val computedCols = ListBuffer[String]()
-        while (computedColRs.next()) {
-            val colName = computedColRs.getString("name")
-            computedCols.append(colName)
-        }
-        computedCols.toList
-    }
-
-    /**
-     * dfComputedColCount
-     * utility function to get number of computed columns in dataframe.
-     * Use number of computed columns in dataframe to get number of non computed column in df,
-     * and compare with the number of non computed column in sql table
-     */
-    private[spark] def dfComputedColCount(
-            dfColNames: List[String],
-            computedCols: List[String],
-            dfColCaseMap: Map[String, String],
-            isCaseSensitive: Boolean): Int ={
-        var dfComputedColCt = 0
-        for (j <- 0 to computedCols.length-1){
-            if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
-                !isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
-                && dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
-                dfComputedColCt += 1
-            }
-        }
-        dfComputedColCt
-    }
-
-
     /**
      * getColMetadataMap
      * Utility function convert result set meta data to array.
@@ -303,37 +263,32 @@ object BulkCopyUtils extends Logging {
         val dfCols = df.schema
 
         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable)
-
         val prefix = "Spark Dataframe and SQL Server table have differing"
 
-        if (computedCols.length == 0) {
-            assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
-        } else if (strictSchemaCheck) {
-            val dfColNames = df.schema.fieldNames.toList
-            val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
-            // if df has computed column(s), check column length using non computed column in df and table.
-            // non computed column number in df: dfCols.length - dfComputedColCt
-            // non computed column number in table: tableCols.length - computedCols.length
-            assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
-        }
+        assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
+            s"${prefix} numbers of columns")
 
-        val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
-        var nonAutoColIndex = 0
+        val result = new ArrayBuffer[ColumnMetadata]()
 
         for (i <- 0 to tableCols.length-1) {
-            val tableColName = tableCols(i).name
-            var dfFieldIndex = -1
-            // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
-            if (computedCols.contains(tableColName)) {
-                logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
-            }else{
+            breakable {
+                val tableColName = tableCols(i).name
+                var dfFieldIndex = 0
                 var dfColName:String = ""
                 if (isCaseSensitive) {
-                    dfFieldIndex = dfCols.fieldIndex(tableColName)
+                    // skip mapping / metadata if table col not in df col (strictSchema check disabled)
+                    logDebug(s"df contains ${tableColName}: ${dfCols.fieldNames.contains(tableColName)}")
+                    if (!strictSchemaCheck && !dfCols.fieldNames.contains(tableColName)) {
+                        logDebug(s"skipping index $i sql col name $tableColName dfFieldIndex $dfFieldIndex")
+                        break
+                    }
+                    try {
+                        dfFieldIndex = dfCols.fieldIndex(tableColName)
+                    } catch {
+                        case ex: IllegalArgumentException => {
+                            throw new SQLException(s"SQL table column ${tableColName} not exist in df columns")
+                        }
+                    }
                     dfColName = dfCols(dfFieldIndex).name
                     assertIfCheckEnabled(
                         tableColName == dfColName, strictSchemaCheck,
@@ -362,28 +317,29 @@ object BulkCopyUtils extends Logging {
                        dfCols(dfFieldIndex).dataType == tableCols(i).dataType,
                        strictSchemaCheck,
                        s"${prefix} column data types at column index ${i}." +
-                       s" DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
-                       s" Table col ${tableColName} dataType ${tableCols(i).dataType} ")
+                           s" DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
+                           s" Table col ${tableColName} dataType ${tableCols(i).dataType} ")
                 }
                 assertIfCheckEnabled(
                     dfCols(dfFieldIndex).nullable == tableCols(i).nullable,
                     strictSchemaCheck,
                     s"${prefix} column nullable configurations at column index ${i}" +
-                    s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
-                    s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
+                        s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
+                        s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
 
-                // Schema check passed for element, Create ColMetaData only for non auto generated column
-                result(nonAutoColIndex) = new ColumnMetadata(
+                // Schema check passed for element, Create ColMetaData
+                result += new ColumnMetadata(
                     rs.getMetaData().getColumnName(i+1),
                     rs.getMetaData().getColumnType(i+1),
                     rs.getMetaData().getPrecision(i+1),
                     rs.getMetaData().getScale(i+1),
                     dfFieldIndex
                 )
-                nonAutoColIndex += 1
+                logDebug(s"one col metadata name: ${rs.getMetaData().getColumnName(i+1)}")
             }
         }
-        result
+        logDebug(s"metadata: ${result.toArray}, column length: ${result.length}")
+        result.toArray
     }
 
     /**
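The rewritten getColMetadataMap loop leans on two plain-Scala idioms: scala.util.control.Breaks, where break exits only the enclosing breakable block so the for loop simply moves on to the next SQL table column when the DataFrame does not supply it, and an ArrayBuffer that collects only the ColumnMetadata entries that survive the checks, since the number of kept columns is no longer known up front. The following minimal, self-contained sketch shows that pattern in isolation; the TableCol and ColMeta case classes and the matchColumns helper are stand-ins invented for illustration, not types from the connector.

    import scala.collection.mutable.ArrayBuffer
    import scala.util.control.Breaks.{breakable, break}

    object ColumnMatchSketch {
      // Hypothetical stand-ins for the Spark schema field and the connector's ColumnMetadata.
      case class TableCol(name: String)
      case class ColMeta(name: String, dfFieldIndex: Int)

      def matchColumns(
          tableCols: Seq[TableCol],
          dfFieldNames: Seq[String],
          strictSchemaCheck: Boolean): Array[ColMeta] = {
        // Append-only buffer: we do not know in advance how many columns will be kept.
        val result = new ArrayBuffer[ColMeta]()
        for (i <- tableCols.indices) {
          breakable {
            val tableColName = tableCols(i).name
            // With strict checking disabled, skip table columns missing from the DataFrame;
            // break leaves only this breakable block, so the outer loop continues.
            if (!strictSchemaCheck && !dfFieldNames.contains(tableColName)) {
              break
            }
            // Mirrors the try/catch around dfCols.fieldIndex in the real code:
            // a column missing under strict checking is an error, not a skip.
            val dfFieldIndex = dfFieldNames.indexOf(tableColName)
            if (dfFieldIndex < 0) {
              throw new IllegalArgumentException(
                s"SQL table column $tableColName not found in df columns")
            }
            result += ColMeta(tableColName, dfFieldIndex)
          }
        }
        result.toArray
      }

      def main(args: Array[String]): Unit = {
        val tableCols = Seq(TableCol("id"), TableCol("name"), TableCol("rowversion"))
        val dfFields = Seq("id", "name")
        // With strict checking disabled, "rowversion" is skipped instead of failing.
        matchColumns(tableCols, dfFields, strictSchemaCheck = false).foreach(println)
      }
    }

Collecting into an ArrayBuffer and calling toArray at the end replaces the previous pattern of pre-sizing an Array to tableCols.length minus the computed-column count and tracking a separate nonAutoColIndex cursor, which only worked because the skip set was known before the loop started.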
