@@ -22,7 +22,8 @@ import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
 import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}
 
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.Breaks.{breakable, break}
 
 /**
  * BulkCopyUtils Object implements common utility function used by both datapool and
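The new imports above back the changes later in this diff: column metadata is now collected in a growable ArrayBuffer, and scala.util.control.Breaks is used to skip individual loop iterations. As a reminder of that idiom (a minimal standalone sketch, not connector code), break inside a breakable block behaves like a "continue" for the current iteration:

import scala.util.control.Breaks.{breakable, break}

for (name <- Seq("id", "skip_me", "value")) {
  breakable {
    // break jumps to the end of the enclosing breakable block,
    // so only the current iteration is abandoned, not the whole loop.
    if (name == "skip_me") break
    println(s"processing $name")
  }
}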
@@ -179,47 +180,6 @@ object BulkCopyUtils extends Logging {
         conn.createStatement.executeQuery(queryStr)
     }
 
-    /**
-     * getComputedCols
-     * utility function to get computed columns.
-     * Use computed column names to exclude computed column when matching schema.
-     */
-    private[spark] def getComputedCols(
-        conn: Connection,
-        table: String): List[String] = {
-        val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
-        val computedColRs = conn.createStatement.executeQuery(queryStr)
-        val computedCols = ListBuffer[String]()
-        while (computedColRs.next()) {
-            val colName = computedColRs.getString("name")
-            computedCols.append(colName)
-        }
-        computedCols.toList
-    }
-
-    /**
-     * dfComputedColCount
-     * utility function to get number of computed columns in dataframe.
-     * Use number of computed columns in dataframe to get number of non computed column in df,
-     * and compare with the number of non computed column in sql table
-     */
-    private[spark] def dfComputedColCount(
-        dfColNames: List[String],
-        computedCols: List[String],
-        dfColCaseMap: Map[String, String],
-        isCaseSensitive: Boolean): Int = {
-        var dfComputedColCt = 0
-        for (j <- 0 to computedCols.length-1){
-            if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
-              !isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
-                && dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
-                dfComputedColCt += 1
-            }
-        }
-        dfComputedColCt
-    }
-
-
     /**
      * getColMetadataMap
      * Utility function convert result set meta data to array.
@@ -303,37 +263,32 @@ object BulkCopyUtils extends Logging {
         val dfCols = df.schema
 
         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable)
-
         val prefix = "Spark Dataframe and SQL Server table have differing"
 
-        if (computedCols.length == 0) {
-            assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
-        } else if (strictSchemaCheck) {
-            val dfColNames = df.schema.fieldNames.toList
-            val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
-            // if df has computed column(s), check column length using non computed column in df and table.
-            // non computed column number in df: dfCols.length - dfComputedColCt
-            // non computed column number in table: tableCols.length - computedCols.length
-            assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
-        }
-
+        assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
+            s"${prefix} numbers of columns")
 
-        val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
-        var nonAutoColIndex = 0
+        val result = new ArrayBuffer[ColumnMetadata]()
 
         for (i <- 0 to tableCols.length-1) {
-            val tableColName = tableCols(i).name
-            var dfFieldIndex = -1
-            // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
-            if (computedCols.contains(tableColName)) {
-                logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
-            } else {
+            breakable {
+                val tableColName = tableCols(i).name
+                var dfFieldIndex = 0
                 var dfColName: String = ""
                 if (isCaseSensitive) {
-                    dfFieldIndex = dfCols.fieldIndex(tableColName)
+                    // skip mapping / metadata if table col not in df col (strictSchema check disabled)
+                    logDebug(s"df contains ${tableColName}: ${dfCols.fieldNames.contains(tableColName)}")
+                    if (!strictSchemaCheck && !dfCols.fieldNames.contains(tableColName)) {
+                        logDebug(s"skipping index $i sql col name $tableColName dfFieldIndex $dfFieldIndex")
+                        break
+                    }
+                    try {
+                        dfFieldIndex = dfCols.fieldIndex(tableColName)
+                    } catch {
+                        case ex: IllegalArgumentException => {
+                            throw new SQLException(s"SQL table column ${tableColName} does not exist in df columns")
+                        }
+                    }
                     dfColName = dfCols(dfFieldIndex).name
                     assertIfCheckEnabled(
                         tableColName == dfColName, strictSchemaCheck,
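A note on the lookup added above: dfCols here is the DataFrame's StructType, whose fieldIndex(name) returns the field's position and throws IllegalArgumentException when the name is absent, while fieldNames exposes the column names as an Array[String]. A small standalone sketch of that behaviour (hypothetical schema, not the connector's):

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("name", StringType, nullable = true)))

schema.fieldNames.contains("name")  // true: cheap membership test before resolving the index
schema.fieldIndex("name")           // 1
// schema.fieldIndex("missing")     // throws IllegalArgumentException, handled in the diff above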
@@ -362,28 +317,29 @@ object BulkCopyUtils extends Logging {
                         dfCols(dfFieldIndex).dataType == tableCols(i).dataType,
                         strictSchemaCheck,
                         s"${prefix} column data types at column index ${i}. " +
-                        s"DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
-                        s"Table col ${tableColName} dataType ${tableCols(i).dataType} ")
+                            s"DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
+                            s"Table col ${tableColName} dataType ${tableCols(i).dataType} ")
                 }
                 assertIfCheckEnabled(
                     dfCols(dfFieldIndex).nullable == tableCols(i).nullable,
                     strictSchemaCheck,
                     s"${prefix} column nullable configurations at column index ${i}" +
-                    s"DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
-                    s"Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
+                        s"DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
+                        s"Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
 
-                // Schema check passed for element, Create ColMetaData only for non auto generated column
-                result(nonAutoColIndex) = new ColumnMetadata(
+                // Schema check passed for element, Create ColMetaData
+                result += new ColumnMetadata(
                     rs.getMetaData().getColumnName(i+1),
                     rs.getMetaData().getColumnType(i+1),
                     rs.getMetaData().getPrecision(i+1),
                     rs.getMetaData().getScale(i+1),
                     dfFieldIndex
                 )
-                nonAutoColIndex += 1
+                logDebug(s"one col metadata name: ${rs.getMetaData().getColumnName(i + 1)} ")
             }
         }
-        result
+        logDebug(s"metadata: ${result.toArray}, column length: ${result.length}")
+        result.toArray
     }
 
     /**
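Net effect of the hunks above: instead of pre-sizing an Array around computed columns, the loop now builds a ColumnMetadata entry only for table columns that can be resolved in the DataFrame, skipping unmatched columns when the strict schema check is disabled. A simplified sketch of that matching rule, with hypothetical names and plain tuples standing in for ColumnMetadata:

import scala.collection.mutable.ArrayBuffer
import scala.util.control.Breaks.{breakable, break}

// Pairs each table column with its DataFrame field index; columns missing from
// the DataFrame are skipped when strict is false, otherwise they fail fast.
def matchColumns(tableCols: Seq[String], dfCols: Seq[String], strict: Boolean): Seq[(String, Int)] = {
  val result = new ArrayBuffer[(String, Int)]()
  for (tableCol <- tableCols) {
    breakable {
      if (!strict && !dfCols.contains(tableCol)) break   // skip unmatched column
      val idx = dfCols.indexOf(tableCol)
      if (idx < 0) throw new IllegalArgumentException(s"column $tableCol not found in DataFrame")
      result += ((tableCol, idx))
    }
  }
  result.toSeq
}

// matchColumns(Seq("id", "rowversion_col", "name"), Seq("id", "name"), strict = false)
// => Seq(("id", 0), ("name", 1))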